diff --git a/awscrt/mqtt.py b/awscrt/mqtt.py index 1da0b0413..2f67c8dc3 100644 --- a/awscrt/mqtt.py +++ b/awscrt/mqtt.py @@ -14,6 +14,7 @@ import awscrt.exceptions from awscrt.http import HttpProxyOptions, HttpRequest from awscrt.io import ClientBootstrap, ClientTlsContext, SocketOptions +from dataclasses import dataclass class QoS(IntEnum): @@ -150,6 +151,22 @@ def __init__(self, bootstrap=None, tls_ctx=None): self._binding = _awscrt.mqtt_client_new(bootstrap, tls_ctx) +@dataclass +class OperationStatisticsData: + """Dataclass containing some simple statistics about the current state of the connection's queue of operations + + Args: + incomplete_operation_count (int): total number of operations submitted to the connection that have not yet been completed. Unacked operations are a subset of this. + incomplete_operation_size (int): total packet size of operations submitted to the connection that have not yet been completed. Unacked operations are a subset of this. + unacked_operation_count (int): total number of operations that have been sent to the server and are waiting for a corresponding ACK before they can be completed. + unacked_operation_size (int): total packet size of operations that have been sent to the server and are waiting for a corresponding ACK before they can be completed. + """ + incomplete_operation_count: int = 0 + incomplete_operation_size: int = 0 + unacked_operation_count: int = 0 + unacked_operation_size: int = 0 + + class Connection(NativeResource): """MQTT client connection. @@ -703,6 +720,30 @@ def puback(packet_id, error_code): return future, packet_id + def get_stats(self): + """Queries the connection's internal statistics for incomplete operations. + + Returns: + A future with a (:class:`OperationStatisticsData`) + """ + + future = Future() + + def get_stats_result( + incomplete_operation_count, + incomplete_operation_size, + unacked_operation_count, + unacked_operation_size): + operation_statistics_data = OperationStatisticsData( + incomplete_operation_count, + incomplete_operation_size, + unacked_operation_count, + unacked_operation_size) + future.set_result(operation_statistics_data) + + _awscrt.mqtt_client_connection_get_stats(self._binding, get_stats_result) + return future + class WebsocketHandshakeTransformArgs: """ diff --git a/crt/aws-c-mqtt b/crt/aws-c-mqtt index 6668ffd60..33c3455ce 160000 --- a/crt/aws-c-mqtt +++ b/crt/aws-c-mqtt @@ -1 +1 @@ -Subproject commit 6668ffd60607bf070c078d46b8c8f4b2cecfc1db +Subproject commit 33c3455cec82b16feb940e12006cefd7b3ef4194 diff --git a/source/module.c b/source/module.c index 20bd856f6..d87b4dca7 100644 --- a/source/module.c +++ b/source/module.c @@ -706,6 +706,7 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(mqtt_client_connection_unsubscribe, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt_client_connection_disconnect, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt_ws_handshake_transform_complete, METH_VARARGS), + AWS_PY_METHOD_DEF(mqtt_client_connection_get_stats, METH_VARARGS), /* MQTT5 Client */ AWS_PY_METHOD_DEF(mqtt5_client_new, METH_VARARGS), diff --git a/source/mqtt_client_connection.c b/source/mqtt_client_connection.c index b408f113b..4e7e6445b 100644 --- a/source/mqtt_client_connection.c +++ b/source/mqtt_client_connection.c @@ -1172,3 +1172,52 @@ PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *arg Py_RETURN_NONE; } + +PyObject *aws_py_mqtt_client_connection_get_stats(PyObject *self, PyObject *args) { + (void)self; + bool success = false; + + PyObject *impl_capsule; + PyObject *get_stats_callback_fn_py; + + if (!PyArg_ParseTuple(args, "OO", &impl_capsule, &get_stats_callback_fn_py)) { + return NULL; + } + + struct mqtt_connection_binding *connection = + PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection); + if (!connection) { + return NULL; + } + + /* These must be DECREF'd when function ends */ + PyObject *result = NULL; + + struct aws_mqtt_connection_operation_statistics stats; + AWS_ZERO_STRUCT(stats); + + aws_mqtt_client_connection_get_stats(connection->native, &stats); + + result = PyObject_CallFunction( + get_stats_callback_fn_py, + "(KKKK)", + /* K */ (unsigned long long)stats.incomplete_operation_count, + /* K */ (unsigned long long)stats.incomplete_operation_size, + /* K */ (unsigned long long)stats.unacked_operation_count, + /* K */ (unsigned long long)stats.unacked_operation_size); + if (!result) { + PyErr_WriteUnraisable(PyErr_Occurred()); + goto done; + } + + success = true; + +done: + + Py_XDECREF(result); + + if (success) { + Py_RETURN_NONE; + } + return NULL; +} diff --git a/source/mqtt_client_connection.h b/source/mqtt_client_connection.h index fc97f6c30..6bdc05017 100644 --- a/source/mqtt_client_connection.h +++ b/source/mqtt_client_connection.h @@ -20,6 +20,7 @@ PyObject *aws_py_mqtt_client_connection_on_message(PyObject *self, PyObject *arg PyObject *aws_py_mqtt_client_connection_unsubscribe(PyObject *self, PyObject *args); PyObject *aws_py_mqtt_client_connection_resubscribe_existing_topics(PyObject *self, PyObject *args); PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *args); +PyObject *aws_py_mqtt_client_connection_get_stats(PyObject *self, PyObject *args); PyObject *aws_py_mqtt_ws_handshake_transform_complete(PyObject *self, PyObject *args); diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 88e1fe7c1..35100e58b 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -237,6 +237,64 @@ def test_connect_disconnect_with_default_singletons(self): EventLoopGroup.release_static_default() DefaultHostResolver.release_static_default() + def test_connect_publish_wait_statistics_disconnect(self): + connection = self._create_connection() + connection.connect().result(TIMEOUT) + + # check operation statistics + statistics = connection.get_stats().result(TIMEOUT) + self.assertEqual(statistics.incomplete_operation_count, 0) + self.assertEqual(statistics.incomplete_operation_size, 0) + self.assertEqual(statistics.unacked_operation_count, 0) + self.assertEqual(statistics.unacked_operation_size, 0) + + # publish + published, packet_id = connection.publish(self.TEST_TOPIC, self.TEST_MSG, QoS.AT_LEAST_ONCE) + puback = published.result(TIMEOUT) + self.assertEqual(packet_id, puback['packet_id']) + + # check operation statistics + statistics = connection.get_stats().result(TIMEOUT) + self.assertEqual(statistics.incomplete_operation_count, 0) + self.assertEqual(statistics.incomplete_operation_size, 0) + self.assertEqual(statistics.unacked_operation_count, 0) + self.assertEqual(statistics.unacked_operation_size, 0) + + # disconnect + connection.disconnect().result(TIMEOUT) + + def test_connect_publish_statistics_wait_disconnect(self): + connection = self._create_connection() + connection.connect().result(TIMEOUT) + + # publish + published, packet_id = connection.publish(self.TEST_TOPIC, self.TEST_MSG, QoS.AT_LEAST_ONCE) + # Per packet: (The size of the topic, the size of the payload, 2 for the header and 2 for the packet ID) + expected_size = len(self.TEST_TOPIC) + len(self.TEST_MSG) + 4 + + # check operation statistics + statistics = connection.get_stats().result(TIMEOUT) + self.assertEqual(statistics.incomplete_operation_count, 1) + self.assertEqual(statistics.incomplete_operation_size, expected_size) + # NOTE: Unacked will be zero because we have not invoked the future yet + # and so it has not had time to move to the socket + self.assertEqual(statistics.unacked_operation_count, 0) + self.assertEqual(statistics.unacked_operation_size, 0) + + # wait for PubAck + puback = published.result(TIMEOUT) + self.assertEqual(packet_id, puback['packet_id']) + + # check operation statistics + statistics = connection.get_stats().result(TIMEOUT) + self.assertEqual(statistics.incomplete_operation_count, 0) + self.assertEqual(statistics.incomplete_operation_size, 0) + self.assertEqual(statistics.unacked_operation_count, 0) + self.assertEqual(statistics.unacked_operation_size, 0) + + # disconnect + connection.disconnect().result(TIMEOUT) + if __name__ == 'main': unittest.main() diff --git a/test/test_mqtt5.py b/test/test_mqtt5.py index 117592944..b8487be6c 100644 --- a/test/test_mqtt5.py +++ b/test/test_mqtt5.py @@ -1226,6 +1226,46 @@ def test_interruption_qos1_publish(self): callbacks.future_stopped.result(TIMEOUT) + # ============================================================== + # MISC TEST CASES + # ============================================================== + + def test_operation_statistics_uc1(self): + client_id_publisher = create_client_id() + payload = "HELLO WORLD" + topic_filter = "test/MQTT5_Binding_Python_" + client_id_publisher + + client_options = mqtt5.ClientOptions("will be replaced", 0) + client_options.connect_options = mqtt5.ConnectPacket(client_id=client_id_publisher) + client1, callbacks = self._test_connect(auth_type=AuthType.DIRECT_MUTUAL_TLS, client_options=client_options) + + # Make sure the operation statistics are empty + statistics = client1.get_stats().result(TIMEOUT) + self.assertEqual(statistics.incomplete_operation_count, 0) + self.assertEqual(statistics.incomplete_operation_size, 0) + self.assertEqual(statistics.unacked_operation_count, 0) + self.assertEqual(statistics.unacked_operation_size, 0) + + publish_packet = mqtt5.PublishPacket( + payload=payload, + topic=topic_filter, + qos=mqtt5.QoS.AT_LEAST_ONCE) + + publishes = 10 + for x in range(publishes): + publish_future = client1.publish(publish_packet) + publish_future.result(TIMEOUT) + + # Make sure the operation statistics are empty + statistics = client1.get_stats().result(TIMEOUT) + self.assertEqual(statistics.incomplete_operation_count, 0) + self.assertEqual(statistics.incomplete_operation_size, 0) + self.assertEqual(statistics.unacked_operation_count, 0) + self.assertEqual(statistics.unacked_operation_size, 0) + + client1.stop() + callbacks.future_stopped.result(TIMEOUT) + if __name__ == 'main': unittest.main()