Skip to content

Commit

Permalink
Treat correlation data as binary not a nullable string (#560)
Browse files Browse the repository at this point in the history
Co-authored-by: Bret Ambrose <bambrose@amazon.com>
  • Loading branch information
bretambrose and Bret Ambrose authored May 22, 2024
1 parent d76c3da commit b5c7972
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 17 deletions.
29 changes: 21 additions & 8 deletions awscrt/mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0.
from typing import Any, Callable
from typing import Any, Callable, Union
import _awscrt
from concurrent.futures import Future
from enum import IntEnum
Expand Down Expand Up @@ -1104,8 +1104,9 @@ class PublishPacket:
message_expiry_interval_sec (int): Sent publishes - indicates the maximum amount of time allowed to elapse for message delivery before the server should instead delete the message (relative to a recipient). Received publishes - indicates the remaining amount of time (from the server's perspective) before the message would have been deleted relative to the subscribing client. If left None, indicates no expiration timeout.
topic_alias (int): An integer value that is used to identify the Topic instead of using the Topic Name. On outbound publishes, this will only be used if the outbound topic aliasing behavior has been set to Manual.
response_topic (str): Opaque topic string intended to assist with request/response implementations. Not internally meaningful to MQTT5 or this client.
correlation_data (Any): Opaque binary data used to correlate between publish messages, as a potential method for request-response implementation. Not internally meaningful to MQTT5.
subscription_identifiers (Sequence[int]): The subscription identifiers of all the subscriptions this message matched.
correlation_data (Optional[Union[bytes, str]]): Deprecated, use `correlation_data_bytes` instead. Opaque binary data used to correlate between publish messages, as a potential method for request-response implementation. Not internally meaningful to MQTT5. For incoming publishes, this will be a utf8 string (if correlation data exists and it's convertible to utf-8) or None (either it didn't exist, or did but wasn't convertible)
correlation_data_bytes (Optional[Union[bytes, str]]): Opaque binary data used to correlate between publish messages, as a potential method for request-response implementation. Not internally meaningful to MQTT5. For outbound publishes, this field takes priority over `correlation_data`. For incoming publishes, this will be binary data if correlation data is set, otherwise it will be None.
subscription_identifiers (Sequence[int]): The subscription identifiers of all the subscriptions this message matched. This field is ignored on outbound publishes (setting it is a protocol error).
content_type (str): Property specifying the content type of the payload. Not internally meaningful to MQTT5.
user_properties (Sequence[UserProperty]): List of MQTT5 user properties included with the packet.
"""
Expand All @@ -1117,7 +1118,9 @@ class PublishPacket:
message_expiry_interval_sec: int = None
topic_alias: int = None
response_topic: str = None
correlation_data: Any = None # Unicode objects are converted to C strings using 'utf-8' encoding
correlation_data_bytes: 'Optional[Union[bytes, str]]' = None # binary data if correlation data exists on the packet
# Deprecated. Incoming publishes: a string if correlation data exists on the packet and is convertible to utf-8
correlation_data: 'Optional[Union[bytes, str]]' = None
subscription_identifiers: 'Sequence[int]' = None # ignore attempts to set but provide in received packets
content_type: str = None
user_properties: 'Sequence[UserProperty]' = None
Expand Down Expand Up @@ -1447,8 +1450,18 @@ def _on_publish(
if topic_alias_exists:
publish_packet.topic_alias = topic_alias
publish_packet.response_topic = response_topic
publish_packet.correlation_data = correlation_data
if publish_packet.subscription_identifiers is not None:

# hacky workaround to maintain behavioral backwards compatibility with deprecated parameter
if correlation_data is not None:
# `correlation_data_bytes` always has the correlation data, as binary data
publish_packet.correlation_data_bytes = correlation_data
try:
# `correlation_data` contains the correlation data as a utf-8 string, if it can be converted
publish_packet.correlation_data = correlation_data.decode("utf-8")
except Exception:
pass

if subscription_identifiers_tuples is not None:
publish_packet.subscription_identifiers = [subscription_identifier
for (subscription_identifier) in subscription_identifiers_tuples]
publish_packet.content_type = content_type
Expand Down Expand Up @@ -1768,7 +1781,7 @@ def __init__(self, client_options: ClientOptions):
will.message_expiry_interval_sec,
will.topic_alias,
will.response_topic,
will.correlation_data,
will.correlation_data_bytes or will.correlation_data,
will.content_type,
will.user_properties,
client_options.session_behavior,
Expand Down Expand Up @@ -1865,7 +1878,7 @@ def puback(error_code, qos, reason_code, reason_string, user_properties_tuples):
publish_packet.message_expiry_interval_sec,
publish_packet.topic_alias,
publish_packet.response_topic,
publish_packet.correlation_data,
publish_packet.correlation_data_bytes or publish_packet.correlation_data,
publish_packet.content_type,
publish_packet.user_properties,
puback)
Expand Down
4 changes: 2 additions & 2 deletions source/mqtt5_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
result = PyObject_CallMethod(
client->client_core,
"_on_publish",
"(y#iOs#OiOIOHs#z#Os#O)",
"(y#iOs#OiOIOHs#y#Os#O)",
/* y */ publish_packet->payload.ptr,
/* # */ publish_packet->payload.len,
/* i */ (int)publish_packet->qos,
Expand All @@ -276,7 +276,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
/* H */ (unsigned short)(publish_packet->topic_alias ? *publish_packet->topic_alias : 0),
/* s */ publish_packet->response_topic ? publish_packet->response_topic->ptr : NULL,
/* # */ publish_packet->response_topic ? publish_packet->response_topic->len : 0,
/* z */ publish_packet->correlation_data ? publish_packet->correlation_data->ptr : NULL,
/* y */ publish_packet->correlation_data ? publish_packet->correlation_data->ptr : NULL,
/* # */ publish_packet->correlation_data ? publish_packet->correlation_data->len : 0,
/* O */ subscription_identifier_count > 0 ? subscription_identifier_list : Py_None,
/* s */ publish_packet->content_type ? publish_packet->content_type->ptr : NULL,
Expand Down
18 changes: 11 additions & 7 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,18 @@ def test_will(self):
host_name=test_input_endpoint,
port=8883,
will=Will(self.TEST_TOPIC, QoS.AT_LEAST_ONCE, self.TEST_MSG, False),
ping_timeout_ms=500,
keep_alive_secs=1
ping_timeout_ms=10000,
keep_alive_secs=30
)
connection.connect().result(TIMEOUT)

subscriber = Connection(
client=client,
client_id=create_client_id(),
host_name=test_input_endpoint,
port=8883
port=8883,
ping_timeout_ms=10000,
keep_alive_secs=30
)
subscriber.connect().result(TIMEOUT)

Expand All @@ -193,16 +195,18 @@ def on_message(**kwargs):
self.assertEqual(self.TEST_TOPIC, suback['topic'])
self.assertIs(QoS.AT_LEAST_ONCE, suback['qos'])

# Disconnect the will client to send the will
time.sleep(1.1) # wait 1.1 seconds to ensure we can make another client ID connect
# wait a few seconds to ensure we can make another client ID connect (don't trigger IoT Core limit)
time.sleep(2)

# Disconnect the will client to send the will by making another connection with the same client id
disconnecter = Connection(
client=client,
client_id=will_client_id,
host_name=test_input_endpoint,
port=8883,
will=Will(self.TEST_TOPIC, QoS.AT_LEAST_ONCE, self.TEST_MSG, False),
ping_timeout_ms=500,
keep_alive_secs=1
ping_timeout_ms=10000,
keep_alive_secs=30
)
disconnecter.connect().result(TIMEOUT)

Expand Down
167 changes: 167 additions & 0 deletions test/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,98 @@ def test_operation_will(self):
client2.stop()
callbacks2.future_stopped.result(TIMEOUT)

def do_will_correlation_data_test(self, outbound_correlation_data_bytes, outbound_correlation_data,
expected_correlation_data_bytes, expected_correlation_data):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")
input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")
input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY")

client_id_publisher = create_client_id()
topic_filter = "test/MQTT5_Binding_Python_" + client_id_publisher

payload = "TEST WILL"
payload_bytes = payload.encode("utf-8")

will_packet = mqtt5.PublishPacket(
payload="TEST WILL",
qos=mqtt5.QoS.AT_LEAST_ONCE,
topic=topic_filter,
correlation_data_bytes=outbound_correlation_data_bytes,
correlation_data=outbound_correlation_data
)

tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path(
input_cert,
input_key
)

client_options1 = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883
)
client_options1.connect_options = mqtt5.ConnectPacket(client_id=client_id_publisher,
will_delay_interval_sec=0,
will=will_packet)
client_options1.tls_ctx = io.ClientTlsContext(tls_ctx_options)
callbacks1 = Mqtt5TestCallbacks()
client1 = self._create_client(client_options=client_options1, callbacks=callbacks1)
client1.start()
callbacks1.future_connection_success.result(TIMEOUT)

client_options2 = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883
)
client_options2.connect_options = mqtt5.ConnectPacket(client_id=create_client_id())
client_options2.tls_ctx = io.ClientTlsContext(tls_ctx_options)
callbacks2 = Mqtt5TestCallbacks()
client2 = self._create_client(client_options=client_options2, callbacks=callbacks2)
client2.start()
callbacks2.future_connection_success.result(TIMEOUT)

subscriptions = []
subscriptions.append(mqtt5.Subscription(topic_filter=topic_filter, qos=mqtt5.QoS.AT_LEAST_ONCE))
subscribe_packet = mqtt5.SubscribePacket(
subscriptions=subscriptions)
subscribe_future = client2.subscribe(subscribe_packet=subscribe_packet)
suback_packet = subscribe_future.result(TIMEOUT)
self.assertIsInstance(suback_packet, mqtt5.SubackPacket)

disconnect_packet = mqtt5.DisconnectPacket(reason_code=mqtt5.DisconnectReasonCode.DISCONNECT_WITH_WILL_MESSAGE)
client1.stop(disconnect_packet=disconnect_packet)
callbacks1.future_stopped.result(TIMEOUT)

received_will = callbacks2.future_publish_received.result(TIMEOUT)
self.assertIsInstance(received_will, mqtt5.PublishPacket)
self.assertEqual(received_will.payload, payload_bytes)
self.assertEqual(received_will.correlation_data_bytes, expected_correlation_data_bytes)
self.assertEqual(received_will.correlation_data, expected_correlation_data)

client2.stop()
callbacks2.future_stopped.result(TIMEOUT)

def test_will_correlation_data_bytes_binary(self):
correlation_data = bytearray(os.urandom(64))
self.do_will_correlation_data_test(correlation_data, None, correlation_data, None)

def test_will_correlation_data_bytes_string(self):
correlation_data = "CorrelationData"
correlation_data_as_bytes = correlation_data.encode('utf-8')
self.do_correlation_data_test(correlation_data, None, correlation_data_as_bytes, correlation_data)

def test_will_correlation_data_binary(self):
correlation_data = bytearray(os.urandom(64))
self.do_correlation_data_test(None, correlation_data, correlation_data, None)

def test_will_correlation_data_string(self):
correlation_data = "CorrelationData"
correlation_data_as_bytes = correlation_data.encode('utf-8')
self.do_correlation_data_test(None, correlation_data, correlation_data_as_bytes, correlation_data)

def test_will_correlation_data_bytes_binary_precedence(self):
correlation_data = bytearray(os.urandom(64))
self.do_correlation_data_test(correlation_data, "Ignored", correlation_data, None)

def test_operation_binary_publish(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")
input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")
Expand Down Expand Up @@ -1292,6 +1384,81 @@ def test_operation_binary_publish(self):
client.stop()
callbacks.future_stopped.result(TIMEOUT)

def do_correlation_data_test(self, outbound_correlation_data_bytes, outbound_correlation_data,
expected_correlation_data_bytes, expected_correlation_data):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")
input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")
input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY")

client_id = create_client_id()
topic_filter = "test/MQTT5_Binding_Python_" + client_id
payload = bytearray(os.urandom(256))

client_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883
)
tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path(
input_cert,
input_key
)
client_options.tls_ctx = io.ClientTlsContext(tls_ctx_options)
callbacks = Mqtt5TestCallbacks()
client = self._create_client(client_options=client_options, callbacks=callbacks)
client.start()
callbacks.future_connection_success.result(TIMEOUT)

subscriptions = []
subscriptions.append(mqtt5.Subscription(topic_filter=topic_filter, qos=mqtt5.QoS.AT_LEAST_ONCE))
subscribe_packet = mqtt5.SubscribePacket(
subscriptions=subscriptions)
subscribe_future = client.subscribe(subscribe_packet=subscribe_packet)
suback_packet = subscribe_future.result(TIMEOUT)
self.assertIsInstance(suback_packet, mqtt5.SubackPacket)

publish_packet = mqtt5.PublishPacket(
payload=payload,
topic=topic_filter,
qos=mqtt5.QoS.AT_LEAST_ONCE,
correlation_data_bytes=outbound_correlation_data_bytes,
correlation_data=outbound_correlation_data)

publish_future = client.publish(publish_packet=publish_packet)
publish_completion_data = publish_future.result(TIMEOUT)
puback_packet = publish_completion_data.puback
self.assertIsInstance(puback_packet, mqtt5.PubackPacket)

received_publish = callbacks.future_publish_received.result(TIMEOUT)
self.assertIsInstance(received_publish, mqtt5.PublishPacket)
self.assertEqual(received_publish.payload, payload)
self.assertEqual(received_publish.correlation_data_bytes, expected_correlation_data_bytes)
self.assertEqual(received_publish.correlation_data, expected_correlation_data)

client.stop()
callbacks.future_stopped.result(TIMEOUT)

def test_operation_publish_correlation_data_bytes_binary(self):
correlation_data = bytearray(os.urandom(64))
self.do_correlation_data_test(correlation_data, None, correlation_data, None)

def test_operation_publish_correlation_data_bytes_string(self):
correlation_data = "CorrelationData"
correlation_data_as_bytes = correlation_data.encode('utf-8')
self.do_correlation_data_test(correlation_data, None, correlation_data_as_bytes, correlation_data)

def test_operation_publish_correlation_data_binary(self):
correlation_data = bytearray(os.urandom(64))
self.do_correlation_data_test(None, correlation_data, correlation_data, None)

def test_operation_publish_correlation_data_string(self):
correlation_data = "CorrelationData"
correlation_data_as_bytes = correlation_data.encode('utf-8')
self.do_correlation_data_test(None, correlation_data, correlation_data_as_bytes, correlation_data)

def test_operation_publish_correlation_data_bytes_binary_precedence(self):
correlation_data = bytearray(os.urandom(64))
self.do_correlation_data_test(correlation_data, "Ignored", correlation_data, None)

# ==============================================================
# OPERATION ERROR TEST CASES
# ==============================================================
Expand Down

0 comments on commit b5c7972

Please sign in to comment.