Skip to content

Commit

Permalink
Mqtt311 operation statistics support (#437)
Browse files Browse the repository at this point in the history
MQTT311 operation statistics support
  • Loading branch information
TwistedTwigleg authored Jan 25, 2023
1 parent 755ceac commit 082ebfe
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 1 deletion.
41 changes: 41 additions & 0 deletions awscrt/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import awscrt.exceptions
from awscrt.http import HttpProxyOptions, HttpRequest
from awscrt.io import ClientBootstrap, ClientTlsContext, SocketOptions
from dataclasses import dataclass


class QoS(IntEnum):
Expand Down Expand Up @@ -150,6 +151,22 @@ def __init__(self, bootstrap=None, tls_ctx=None):
self._binding = _awscrt.mqtt_client_new(bootstrap, tls_ctx)


@dataclass
class OperationStatisticsData:
"""Dataclass containing some simple statistics about the current state of the connection's queue of operations
Args:
incomplete_operation_count (int): total number of operations submitted to the connection that have not yet been completed. Unacked operations are a subset of this.
incomplete_operation_size (int): total packet size of operations submitted to the connection that have not yet been completed. Unacked operations are a subset of this.
unacked_operation_count (int): total number of operations that have been sent to the server and are waiting for a corresponding ACK before they can be completed.
unacked_operation_size (int): total packet size of operations that have been sent to the server and are waiting for a corresponding ACK before they can be completed.
"""
incomplete_operation_count: int = 0
incomplete_operation_size: int = 0
unacked_operation_count: int = 0
unacked_operation_size: int = 0


class Connection(NativeResource):
"""MQTT client connection.
Expand Down Expand Up @@ -703,6 +720,30 @@ def puback(packet_id, error_code):

return future, packet_id

def get_stats(self):
"""Queries the connection's internal statistics for incomplete operations.
Returns:
A future with a (:class:`OperationStatisticsData`)
"""

future = Future()

def get_stats_result(
incomplete_operation_count,
incomplete_operation_size,
unacked_operation_count,
unacked_operation_size):
operation_statistics_data = OperationStatisticsData(
incomplete_operation_count,
incomplete_operation_size,
unacked_operation_count,
unacked_operation_size)
future.set_result(operation_statistics_data)

_awscrt.mqtt_client_connection_get_stats(self._binding, get_stats_result)
return future


class WebsocketHandshakeTransformArgs:
"""
Expand Down
2 changes: 1 addition & 1 deletion crt/aws-c-mqtt
1 change: 1 addition & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ static PyMethodDef s_module_methods[] = {
AWS_PY_METHOD_DEF(mqtt_client_connection_unsubscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_disconnect, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_ws_handshake_transform_complete, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt_client_connection_get_stats, METH_VARARGS),

/* MQTT5 Client */
AWS_PY_METHOD_DEF(mqtt5_client_new, METH_VARARGS),
Expand Down
49 changes: 49 additions & 0 deletions source/mqtt_client_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1172,3 +1172,52 @@ PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *arg

Py_RETURN_NONE;
}

PyObject *aws_py_mqtt_client_connection_get_stats(PyObject *self, PyObject *args) {
(void)self;
bool success = false;

PyObject *impl_capsule;
PyObject *get_stats_callback_fn_py;

if (!PyArg_ParseTuple(args, "OO", &impl_capsule, &get_stats_callback_fn_py)) {
return NULL;
}

struct mqtt_connection_binding *connection =
PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt_client_connection);
if (!connection) {
return NULL;
}

/* These must be DECREF'd when function ends */
PyObject *result = NULL;

struct aws_mqtt_connection_operation_statistics stats;
AWS_ZERO_STRUCT(stats);

aws_mqtt_client_connection_get_stats(connection->native, &stats);

result = PyObject_CallFunction(
get_stats_callback_fn_py,
"(KKKK)",
/* K */ (unsigned long long)stats.incomplete_operation_count,
/* K */ (unsigned long long)stats.incomplete_operation_size,
/* K */ (unsigned long long)stats.unacked_operation_count,
/* K */ (unsigned long long)stats.unacked_operation_size);
if (!result) {
PyErr_WriteUnraisable(PyErr_Occurred());
goto done;
}

success = true;

done:

Py_XDECREF(result);

if (success) {
Py_RETURN_NONE;
}
return NULL;
}
1 change: 1 addition & 0 deletions source/mqtt_client_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ PyObject *aws_py_mqtt_client_connection_on_message(PyObject *self, PyObject *arg
PyObject *aws_py_mqtt_client_connection_unsubscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_resubscribe_existing_topics(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt_client_connection_get_stats(PyObject *self, PyObject *args);

PyObject *aws_py_mqtt_ws_handshake_transform_complete(PyObject *self, PyObject *args);

Expand Down
58 changes: 58 additions & 0 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,64 @@ def test_connect_disconnect_with_default_singletons(self):
EventLoopGroup.release_static_default()
DefaultHostResolver.release_static_default()

def test_connect_publish_wait_statistics_disconnect(self):
connection = self._create_connection()
connection.connect().result(TIMEOUT)

# check operation statistics
statistics = connection.get_stats().result(TIMEOUT)
self.assertEqual(statistics.incomplete_operation_count, 0)
self.assertEqual(statistics.incomplete_operation_size, 0)
self.assertEqual(statistics.unacked_operation_count, 0)
self.assertEqual(statistics.unacked_operation_size, 0)

# publish
published, packet_id = connection.publish(self.TEST_TOPIC, self.TEST_MSG, QoS.AT_LEAST_ONCE)
puback = published.result(TIMEOUT)
self.assertEqual(packet_id, puback['packet_id'])

# check operation statistics
statistics = connection.get_stats().result(TIMEOUT)
self.assertEqual(statistics.incomplete_operation_count, 0)
self.assertEqual(statistics.incomplete_operation_size, 0)
self.assertEqual(statistics.unacked_operation_count, 0)
self.assertEqual(statistics.unacked_operation_size, 0)

# disconnect
connection.disconnect().result(TIMEOUT)

def test_connect_publish_statistics_wait_disconnect(self):
connection = self._create_connection()
connection.connect().result(TIMEOUT)

# publish
published, packet_id = connection.publish(self.TEST_TOPIC, self.TEST_MSG, QoS.AT_LEAST_ONCE)
# Per packet: (The size of the topic, the size of the payload, 2 for the header and 2 for the packet ID)
expected_size = len(self.TEST_TOPIC) + len(self.TEST_MSG) + 4

# check operation statistics
statistics = connection.get_stats().result(TIMEOUT)
self.assertEqual(statistics.incomplete_operation_count, 1)
self.assertEqual(statistics.incomplete_operation_size, expected_size)
# NOTE: Unacked will be zero because we have not invoked the future yet
# and so it has not had time to move to the socket
self.assertEqual(statistics.unacked_operation_count, 0)
self.assertEqual(statistics.unacked_operation_size, 0)

# wait for PubAck
puback = published.result(TIMEOUT)
self.assertEqual(packet_id, puback['packet_id'])

# check operation statistics
statistics = connection.get_stats().result(TIMEOUT)
self.assertEqual(statistics.incomplete_operation_count, 0)
self.assertEqual(statistics.incomplete_operation_size, 0)
self.assertEqual(statistics.unacked_operation_count, 0)
self.assertEqual(statistics.unacked_operation_size, 0)

# disconnect
connection.disconnect().result(TIMEOUT)


if __name__ == 'main':
unittest.main()
40 changes: 40 additions & 0 deletions test/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,46 @@ def test_interruption_qos1_publish(self):

callbacks.future_stopped.result(TIMEOUT)

# ==============================================================
# MISC TEST CASES
# ==============================================================

def test_operation_statistics_uc1(self):
client_id_publisher = create_client_id()
payload = "HELLO WORLD"
topic_filter = "test/MQTT5_Binding_Python_" + client_id_publisher

client_options = mqtt5.ClientOptions("will be replaced", 0)
client_options.connect_options = mqtt5.ConnectPacket(client_id=client_id_publisher)
client1, callbacks = self._test_connect(auth_type=AuthType.DIRECT_MUTUAL_TLS, client_options=client_options)

# Make sure the operation statistics are empty
statistics = client1.get_stats().result(TIMEOUT)
self.assertEqual(statistics.incomplete_operation_count, 0)
self.assertEqual(statistics.incomplete_operation_size, 0)
self.assertEqual(statistics.unacked_operation_count, 0)
self.assertEqual(statistics.unacked_operation_size, 0)

publish_packet = mqtt5.PublishPacket(
payload=payload,
topic=topic_filter,
qos=mqtt5.QoS.AT_LEAST_ONCE)

publishes = 10
for x in range(publishes):
publish_future = client1.publish(publish_packet)
publish_future.result(TIMEOUT)

# Make sure the operation statistics are empty
statistics = client1.get_stats().result(TIMEOUT)
self.assertEqual(statistics.incomplete_operation_count, 0)
self.assertEqual(statistics.incomplete_operation_size, 0)
self.assertEqual(statistics.unacked_operation_count, 0)
self.assertEqual(statistics.unacked_operation_size, 0)

client1.stop()
callbacks.future_stopped.result(TIMEOUT)


if __name__ == 'main':
unittest.main()

0 comments on commit 082ebfe

Please sign in to comment.