From 57422909408fc826c6de7be0f6b4703247950105 Mon Sep 17 00:00:00 2001 From: Alfred G <28123637+alfred2g@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:34:08 -0800 Subject: [PATCH] Shared Subscription tests (#540) --- test/test_mqtt5.py | 162 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/test/test_mqtt5.py b/test/test_mqtt5.py index 915616f00..6446285e6 100644 --- a/test/test_mqtt5.py +++ b/test/test_mqtt5.py @@ -4,6 +4,7 @@ from concurrent.futures import Future from awscrt import mqtt5, io, http, exceptions from test import NativeResourceTest +from threading import Lock import os import unittest import uuid @@ -1008,6 +1009,167 @@ def test_operation_sub_unsub(self): client.stop() callbacks.future_stopped.result(TIMEOUT) + sub1_callbacks = False + sub2_callbacks = False + total_callbacks = 0 + all_packets_received = Future() + mutex = Lock() + received_subscriptions = [0] * 10 + + def subscriber1_callback(self, publish_received_data: mqtt5.PublishReceivedData): + self.mutex.acquire() + var = publish_received_data.publish_packet.payload + self.received_subscriptions[int(var)] = 1 + self.sub1_callbacks = True + self.total_callbacks = self.total_callbacks + 1 + if self.total_callbacks == 10: + self.all_packets_received.set_result(None) + self.mutex.release() + + def subscriber2_callback(self, publish_received_data: mqtt5.PublishReceivedData): + self.mutex.acquire() + var = publish_received_data.publish_packet.payload + self.received_subscriptions[int(var)] = 1 + self.sub2_callbacks = True + self.total_callbacks = self.total_callbacks + 1 + if self.total_callbacks == 10: + self.all_packets_received.set_result(None) + self.mutex.release() + + def test_operation_shared_subscription(self): + input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST") + input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") + input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY") + + client_id_subscriber1 = create_client_id() + client_id_subscriber2 = create_client_id() + client_id_publisher = create_client_id() + + testTopic = "test/MQTT5_Binding_Python_" + client_id_publisher + sharedTopicfilter = "$share/crttest/test/MQTT5_Binding_Python_" + client_id_publisher + + tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path( + input_cert, + input_key + ) + + # subscriber 1 + connect_subscriber1_options = mqtt5.ConnectPacket(client_id=client_id_subscriber1) + subscriber1_generic_callback = Mqtt5TestCallbacks() + subscriber1_options = mqtt5.ClientOptions( + host_name=input_host_name, + port=8883, + tls_ctx=io.ClientTlsContext(tls_ctx_options), + connect_options=connect_subscriber1_options, + on_publish_callback_fn=self.subscriber1_callback, + on_lifecycle_event_stopped_fn=subscriber1_generic_callback.on_lifecycle_stopped, + on_lifecycle_event_attempting_connect_fn=subscriber1_generic_callback.on_lifecycle_attempting_connect, + on_lifecycle_event_connection_success_fn=subscriber1_generic_callback.on_lifecycle_connection_success, + on_lifecycle_event_connection_failure_fn=subscriber1_generic_callback.on_lifecycle_connection_failure + ) + subscriber1_client = mqtt5.Client(client_options=subscriber1_options) + + # subscriber 2 + connect_subscriber2_options = mqtt5.ConnectPacket(client_id=client_id_subscriber2) + subscriber2_generic_callback = Mqtt5TestCallbacks() + subscriber2_options = mqtt5.ClientOptions( + host_name=input_host_name, + port=8883, + tls_ctx=io.ClientTlsContext(tls_ctx_options), + connect_options=connect_subscriber2_options, + on_publish_callback_fn=self.subscriber2_callback, + on_lifecycle_event_stopped_fn=subscriber2_generic_callback.on_lifecycle_stopped, + on_lifecycle_event_attempting_connect_fn=subscriber2_generic_callback.on_lifecycle_attempting_connect, + on_lifecycle_event_connection_success_fn=subscriber2_generic_callback.on_lifecycle_connection_success, + on_lifecycle_event_connection_failure_fn=subscriber2_generic_callback.on_lifecycle_connection_failure + ) + subscriber2_client = mqtt5.Client(client_options=subscriber2_options) + + # publisher + connect_publisher_options = mqtt5.ConnectPacket(client_id=client_id_publisher) + publisher_generic_callback = Mqtt5TestCallbacks() + + publisher_options = mqtt5.ClientOptions( + host_name=input_host_name, + port=8883, + tls_ctx=io.ClientTlsContext(tls_ctx_options), + connect_options=connect_publisher_options, + on_lifecycle_event_stopped_fn=publisher_generic_callback.on_lifecycle_stopped, + on_lifecycle_event_attempting_connect_fn=publisher_generic_callback.on_lifecycle_attempting_connect, + on_lifecycle_event_connection_success_fn=publisher_generic_callback.on_lifecycle_connection_success, + on_lifecycle_event_connection_failure_fn=publisher_generic_callback.on_lifecycle_connection_failure + ) + publisher_client = mqtt5.Client(client_options=publisher_options) + + print("Connecting all 3 clients\n") + subscriber1_client.start() + subscriber1_generic_callback.future_connection_success.result(TIMEOUT) + + subscriber2_client.start() + subscriber2_generic_callback.future_connection_success.result(TIMEOUT) + + publisher_client.start() + publisher_generic_callback.future_connection_success.result(TIMEOUT) + print("All clients connected\n") + + # Subscriber 1 + subscriptions = [] + subscriptions.append(mqtt5.Subscription(topic_filter=sharedTopicfilter, qos=mqtt5.QoS.AT_LEAST_ONCE)) + subscribe_packet = mqtt5.SubscribePacket( + subscriptions=subscriptions) + subscribe_future = subscriber1_client.subscribe(subscribe_packet=subscribe_packet) + suback_packet1 = subscribe_future.result(TIMEOUT) + self.assertIsInstance(suback_packet1, mqtt5.SubackPacket) + + # Subscriber 2 + subscriptions2 = [] + subscriptions2.append(mqtt5.Subscription(topic_filter=sharedTopicfilter, qos=mqtt5.QoS.AT_LEAST_ONCE)) + subscribe_packet2 = mqtt5.SubscribePacket( + subscriptions=subscriptions2) + subscribe_future2 = subscriber2_client.subscribe(subscribe_packet=subscribe_packet2) + suback_packet2 = subscribe_future2.result(TIMEOUT) + self.assertIsInstance(suback_packet2, mqtt5.SubackPacket) + + publishes = 10 + for x in range(0, publishes): + packet = mqtt5.PublishPacket( + payload=f"{x}", + qos=mqtt5.QoS.AT_LEAST_ONCE, + topic=testTopic + ) + publish_future = publisher_client.publish(packet) + publish_future.result(TIMEOUT) + + self.all_packets_received.result(TIMEOUT) + + topic_filters = [] + topic_filters.append(testTopic) + unsubscribe_packet = mqtt5.UnsubscribePacket(topic_filters=testTopic) + + unsubscribe_future = subscriber1_client.unsubscribe(unsubscribe_packet) + unsuback_packet = unsubscribe_future.result(TIMEOUT) + self.assertIsInstance(unsuback_packet, mqtt5.UnsubackPacket) + + unsubscribe_future = subscriber2_client.unsubscribe(unsubscribe_packet) + unsuback_packet = unsubscribe_future.result(TIMEOUT) + self.assertIsInstance(unsuback_packet, mqtt5.UnsubackPacket) + + self.assertEqual(self.sub1_callbacks, True) + self.assertEqual(self.sub2_callbacks, True) + self.assertEqual(self.total_callbacks, 10) + + for e in self.received_subscriptions: + self.assertEqual(e, 1) + + subscriber1_client.stop() + subscriber1_generic_callback.future_stopped.result(TIMEOUT) + + subscriber2_client.stop() + subscriber2_generic_callback.future_stopped.result(TIMEOUT) + + publisher_client.stop() + publisher_generic_callback.future_stopped.result(TIMEOUT) + def test_operation_will(self): input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST") input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")