diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel index 2bbf804b0c5..17b9e64f70e 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel @@ -48,3 +48,11 @@ pl_cc_test( ":cc_library", ], ) + +pl_cc_test( + name = "stitcher_test", + srcs = ["stitcher_test.cc"], + deps = [ + ":cc_library", + ], +) diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc index a111c7c818d..35f5ba22fd2 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "src/common/base/base.h" #include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h" @@ -37,24 +38,6 @@ namespace protocols { namespace mqtt { -enum class MqttControlPacketType : uint8_t { - CONNECT = 1, - CONNACK = 2, - PUBLISH = 3, - PUBACK = 4, - PUBREC = 5, - PUBREL = 6, - PUBCOMP = 7, - SUBSCRIBE = 8, - SUBACK = 9, - UNSUBSCRIBE = 10, - UNSUBACK = 11, - PINGREQ = 12, - PINGRESP = 13, - DISCONNECT = 14, - AUTH = 15 -}; - enum class PropertyCode : uint8_t { PayloadFormatIndicator = 0x01, MessageExpiryInterval = 0x02, @@ -653,7 +636,8 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder, } } -ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* result) { +ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* result, + mqtt::StateWrapper* state) { CTX_DCHECK(type == message_type_t::kRequest || type == message_type_t::kResponse); if (buf->size() < 2) { return ParseState::kNeedsMoreData; @@ -723,6 +707,19 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* resul return ParseState::kInvalid; } + // Updating the state for PUBLISH based on whether it is duplicate + if (control_packet_type == MqttControlPacketType::PUBLISH) { + if (result->dup) { + if (type == message_type_t::kRequest) { + state->send[std::tuple(result->header_fields["packet_identifier"], + result->header_fields["qos"])] += 1; + } else { + state->recv[std::tuple(result->header_fields["packet_identifier"], + result->header_fields["qos"])] += 1; + } + } + } + if (ParsePayload(result, &decoder, control_packet_type) == ParseState::kInvalid) { return ParseState::kInvalid; } @@ -735,8 +732,8 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* resul template <> ParseState ParseFrame(message_type_t type, std::string_view* buf, mqtt::Message* result, - NoState* /*state*/) { - return mqtt::ParseFrame(type, buf, result); + mqtt::StateWrapper* state) { + return mqtt::ParseFrame(type, buf, result, state); } template <> @@ -745,6 +742,11 @@ size_t FindFrameBoundary(message_type_t /*type*/, std::string_vie return start_pos + buf.length(); } +template <> +mqtt::packet_id_t GetStreamID(mqtt::Message* message) { + return message->header_fields["packet_identifier"]; +} + } // namespace protocols } // namespace stirling } // namespace px diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h index 7220550457b..a4fdc6ee982 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h @@ -35,11 +35,14 @@ namespace protocols { template <> ParseState ParseFrame(message_type_t type, std::string_view* buf, mqtt::Message* frame, - NoState* state); + mqtt::StateWrapper* state); template <> size_t FindFrameBoundary(message_type_t type, std::string_view buf, size_t start_pos, - NoState* state); + mqtt::StateWrapper* state); + +template <> +mqtt::packet_id_t GetStreamID(mqtt::Message* message); } // namespace protocols } // namespace stirling diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc index 7c5eac3ee83..5d8352f8a2c 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc @@ -16,7 +16,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include #include #include @@ -32,6 +31,7 @@ class MQTTParserTest : public ::testing::Test {}; TEST_F(MQTTParserTest, Properties) { Message frame; + StateWrapper* state = nullptr; ParseState result_state; std::string_view frame_view; @@ -333,45 +333,45 @@ TEST_F(MQTTParserTest, Properties) { frame_view = CreateStringView(CharArrayStringView(payload_format_indicator_publish)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["payload_format"], "utf-8"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(message_expiry_interval_publish)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["message_expiry_interval"], "65536000"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(content_type_publish)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["content_type"], "application/json"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(response_topic_publish)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["response_topic"], "ABCXYZ"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(correlation_data_publish)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["correlation_data"], "ABCXYZ"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(subscription_id_subscribe)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["subscription_id"], "14860801"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(session_exp_int_recv_max_connect)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["session_expiry_interval"], "1000000"); EXPECT_EQ(frame.properties["receive_maximum"], "20"); @@ -379,7 +379,7 @@ TEST_F(MQTTParserTest, Properties) { frame_view = CreateStringView( CharArrayStringView(assigned_cid_topic_alias_max_recv_max_connack)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["assigned_client_identifier"], "auto-C8063868-7804-3F66-06B8-BACD9F673CB0"); @@ -388,50 +388,50 @@ TEST_F(MQTTParserTest, Properties) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(subscription_id_subscribe)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["subscription_id"], "14860801"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(auth_method_data_connect)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["auth_method"], "SCRAM-SHA-256"); EXPECT_EQ(frame.properties["auth_data"], "client-first-message"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(req_prob_info_connect)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["request_problem_information"], "1"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(will_delay_interval_connect)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["will_delay_interval"], "30"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(req_resp_info_connect)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["request_response_information"], "1"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(topic_alias_publish)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["topic_alias"], "100"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(user_prop_subscribe)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["user-properties"], "{examplekey:examplevalue}"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(max_packet_size_connect)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.properties["maximum_packet_size"], "1048576"); frame = Message(); @@ -439,6 +439,7 @@ TEST_F(MQTTParserTest, Properties) { TEST_F(MQTTParserTest, Payload) { Message frame; + StateWrapper* state = nullptr; ParseState result_state; std::string_view frame_view; @@ -534,7 +535,7 @@ TEST_F(MQTTParserTest, Payload) { 0x00}; frame_view = CreateStringView(CharArrayStringView(kConnectFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.payload["will_topic"], "will-topic"); EXPECT_EQ(frame.payload["will_payload"], "goodbye"); @@ -542,13 +543,13 @@ TEST_F(MQTTParserTest, Payload) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPublishFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.payload["publish_message"], "hello world"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kSubscribeFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.payload["topic_filter"], "test/topic"); std::map subscription_opts( @@ -557,19 +558,19 @@ TEST_F(MQTTParserTest, Payload) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kSubackFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.payload["reason_code"], "0"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kUnsubscribeFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.payload["topic_filter"], "test/topic"); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kUnsubackFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.payload["reason_code"], "0"); frame = Message(); @@ -577,6 +578,7 @@ TEST_F(MQTTParserTest, Payload) { TEST_F(MQTTParserTest, Headers) { Message frame; + StateWrapper* state = nullptr; std::string_view frame_view; ParseState result_state; @@ -727,7 +729,7 @@ TEST_F(MQTTParserTest, Headers) { 0x18}; frame_view = CreateStringView(CharArrayStringView(kConnectFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 1); EXPECT_EQ(frame.header_fields["remaining_length"], 16); @@ -741,7 +743,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kConnackFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 2); EXPECT_EQ(frame.header_fields["remaining_length"], 53); @@ -750,7 +752,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPublishFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 3); EXPECT_EQ(frame.header_fields["remaining_length"], 26); @@ -761,7 +763,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPubackFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 4); EXPECT_EQ(frame.header_fields["remaining_length"], 3); @@ -769,7 +771,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPubrecFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 5); EXPECT_EQ(frame.header_fields["remaining_length"], 2); @@ -777,7 +779,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPubrelFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 6); EXPECT_EQ(frame.header_fields["remaining_length"], 2); @@ -785,7 +787,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPubcompFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 7); EXPECT_EQ(frame.header_fields["remaining_length"], 2); @@ -793,7 +795,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kSubscribeFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 8); EXPECT_EQ(frame.header_fields["remaining_length"], 16); @@ -801,7 +803,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kSubackFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 9); EXPECT_EQ(frame.header_fields["remaining_length"], 4); @@ -810,7 +812,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kUnsubscribeFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 10); EXPECT_EQ(frame.header_fields["remaining_length"], 15); @@ -818,7 +820,7 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kUnsubackFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 11); EXPECT_EQ(frame.header_fields["remaining_length"], 4); @@ -826,21 +828,21 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPingreqFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 12); EXPECT_EQ(frame.header_fields["remaining_length"], 0); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPingrespFrame)); - result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 13); EXPECT_EQ(frame.header_fields["remaining_length"], 0); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kDisconnectFrame)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 14); EXPECT_EQ(frame.header_fields["remaining_length"], 1); @@ -848,14 +850,14 @@ TEST_F(MQTTParserTest, Headers) { frame = Message(); frame_view = CreateStringView(CharArrayStringView(kAuthFrame_success)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 15); EXPECT_EQ(frame.header_fields["reason_code"], 0); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kAuthFrame_cont_auth)); - result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); + result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame, state); ASSERT_EQ(result_state, ParseState::kSuccess); EXPECT_EQ(frame.control_packet_type, 15); EXPECT_EQ(frame.header_fields["reason_code"], 0x18); diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.cc b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.cc new file mode 100644 index 00000000000..54945328fe3 --- /dev/null +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.cc @@ -0,0 +1,225 @@ +/* + * Copyright 2018- The Pixie Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "src/common/json/json.h" +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h" + +namespace px { +namespace stirling { +namespace protocols { +namespace mqtt { + +// MatchKey layout, || control_packet_type (4 bits) | dup (1 bit) | qos (2 bits) | retain (1 bit) || +typedef uint8_t MatchKey; + +constexpr MatchKey UnmatchedResp = 0xff; + +std::map MapRequestToResponse = { + // CONNECT to CONNACK + {0x10, 0x20}, + // PUBLISH QOS 0 to Dummy response + {0x30, UnmatchedResp}, + {0x31, UnmatchedResp}, + {0x38, UnmatchedResp}, + {0x39, UnmatchedResp}, + // PUBLISH QOS 1 to PUBACK + {0x32, 0x40}, + {0x33, 0x40}, + {0x3a, 0x40}, + {0x3b, 0x40}, + // PUBLISH QOS 2 to PUBREC + {0x34, 0x50}, + {0x35, 0x50}, + {0x3c, 0x50}, + {0x3d, 0x50}, + // PUBREL to PUBCOMP + {0x60, 0x70}, + // SUBSCRIBE to SUBACK + {0X80, 0X90}, + // UNSUBSCRIBE to UNSUBACK + {0xa0, 0xb0}, + // PINGREQ to PINGRESP + {0xc0, 0xd0}, + // DISCONNECT to Dummy response + {0xe0, UnmatchedResp}}; + +// Possible to have the server sending PUBLISH with same packet identifier as client PUBLISH before +// it sends PUBACK, causing server PUBLISH to be put into response deque instead of request deque. +// TODO(ChinmayaSharma-hue): Reverse logic to match requests that have erroneously been put into +// response deque + +MatchKey getMatchKey(mqtt::Message* frame) { + return (frame->control_packet_type << 4) | static_cast(frame->dup) << 3 | + (frame->header_fields["qos"] & 0x3) << 1 | static_cast(frame->retain); +} + +RecordsWithErrorCount StitchFrames( + absl::flat_hash_map>* req_frames, + absl::flat_hash_map>* resp_frames, mqtt::StateWrapper* state) { + std::vector entries; + int error_count = 0; + + // iterate through all deques of requests associated with a specific streamID and find the + // matching response + for (auto& [packet_id, req_deque] : *req_frames) { + // goal is to match the request to the closest appropriate response to the specific control type + // based on timestamp + + // get the response deque corresponding to the packet ID of the request deque + auto pos = resp_frames->find(packet_id); + // note that not finding a corresponding response deque is not indicative of error, as in + // case of MQTT packets that do not have responses like Publish with QOS 0 + std::deque empty_deque; + std::deque& resp_deque = (pos != resp_frames->end()) ? pos->second : empty_deque; + + // track the latest response timestamp to compare against request frame's timestamp later. + uint64_t latest_resp_ts = resp_deque.empty() ? 0 : resp_deque.back().timestamp_ns; + // finding the closest appropriate response from response deque in terms of timestamp and type + // for each request in the request deque + for (mqtt::Message& req_frame : req_deque) { + const MqttControlPacketType control_packet_type = + magic_enum::enum_cast(req_frame.control_packet_type).value(); + // If the frame is AUTH, then do not classify, as request AUTH first comes from the server + // side which might be classified as response, so would be present in the response deque and + // not the request deque + // TODO(ChinmayaSharma-hue): Handling of AUTH matching + if (control_packet_type == MqttControlPacketType::AUTH) { + req_frame.consumed = true; + continue; + } + // If the frame is PUBLISH, and there are duplicates in the deque, then mark the frame as + // consumed and match the latest duplicate with its response (if the response exists in the + // response deque) + if (control_packet_type == MqttControlPacketType::PUBLISH) { + std::tuple unique_publish_identifier = std::tuple( + req_frame.header_fields["packet_identifier"], req_frame.header_fields["qos"]); + if (req_frame.type == message_type_t::kRequest) { + auto it = state->send.find(unique_publish_identifier); + if (it != state->send.end() && it->second > 0) { + it->second -= 1; + req_frame.consumed = true; + continue; + } + } + + if (req_frame.type == message_type_t::kResponse) { + auto it = state->recv.find(unique_publish_identifier); + if (it != state->recv.end() && it->second > 0) { + it->second -= 1; + req_frame.consumed = true; + continue; + } + } + } + + // getting the appropriate response match value for the request match key + MatchKey request_match_key = getMatchKey(&req_frame); + auto iter = MapRequestToResponse.find(request_match_key); + if (iter == MapRequestToResponse.end()) { + VLOG(1) << absl::Substitute("Could not find any responses for frame type = $0", + request_match_key); + continue; + } + if (iter->second == UnmatchedResp) { + // Request without responses found + req_frame.consumed = true; + latest_resp_ts = req_frame.timestamp_ns + 1; + mqtt::Message dummy_resp; + entries.push_back({std::move(req_frame), std::move(dummy_resp)}); + continue; + } + MatchKey response_match_value = iter->second; + + // finding the first response frame with timestamp greater than request frame + auto first_timestamp_iter = + std::lower_bound(resp_deque.begin(), resp_deque.end(), req_frame.timestamp_ns, + [](const mqtt::Message& message, const uint64_t ts) { + return ts > message.timestamp_ns; + }); + if (first_timestamp_iter == resp_deque.end()) { + VLOG(1) << absl::Substitute("Could not find any responses after timestamp = $0", + req_frame.timestamp_ns); + continue; + } + + // finding the first appropriate response frame with the desired control packet type and flags + auto response_frame_iter = std::find_if( + first_timestamp_iter, resp_deque.end(), [response_match_value](mqtt::Message& message) { + return (getMatchKey(&message) == response_match_value) & !message.consumed; + }); + if (response_frame_iter == resp_deque.end()) { + VLOG(1) << absl::Substitute( + "Could not find any responses with control packet type and flag = $0", + response_match_value); + continue; + } + mqtt::Message& resp_frame = *response_frame_iter; + + req_frame.consumed = true; + resp_frame.consumed = true; + entries.push_back({std::move(req_frame), std::move(resp_frame)}); + } + + // clearing the req_deque and resp_deque + auto erase_until_iter = req_deque.begin(); + auto iter = req_deque.begin(); + while (iter != req_deque.end() && (iter->timestamp_ns < latest_resp_ts)) { + if (iter->consumed) { + ++erase_until_iter; + } + if (!iter->consumed && !(iter == req_deque.end() - 1) && ((erase_until_iter + 1)->consumed)) { + ++error_count; + ++erase_until_iter; + } + ++iter; + } + req_deque.erase(req_deque.begin(), erase_until_iter); + } + + // Verify which deque server side PUBLISH frames are inserted into. It's suspected that these + // PUBLISH requests will end up in the resp deque and will cause the resp deque cleanup logic to + // erroneously drop request frames + // TODO(ChinmayaSharma-hue): Verify that the frames in response deque are not request frames + // before dropping + + // iterate through all response dequeues to find out which ones haven't been consumed + for (auto& [packet_id, resp_deque] : *resp_frames) { + for (auto& resp : resp_deque) { + if (!resp.consumed) { + error_count++; + } + } + resp_deque.clear(); + } + + return {entries, error_count}; +} +} // namespace mqtt +} // namespace protocols +} // namespace stirling +} // namespace px diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.h b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.h new file mode 100644 index 00000000000..43b1056cbf6 --- /dev/null +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.h @@ -0,0 +1,60 @@ +/* + * Copyright 2018- The Pixie Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +#include "src/stirling/source_connectors/socket_tracer/protocols/common/interface.h" +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h" + +namespace px { +namespace stirling { +namespace protocols { +namespace mqtt { + +/** + * StitchFrames is the entry point of the MQTT Stitcher. It loops through the req_frames, + * matches them with the corresponding resp_frames, and optionally produces an entry to emit. + * + * @param req_frames: deque of all request frames. + * @param resp_frames: deque of all response frames. + * @param resp_frames: state holding send and recv maps, which are key-value pairs of (packet id, + * qos) and dup counter. + * @return A vector of entries to be appended to table store. + */ +RecordsWithErrorCount StitchFrames( + absl::flat_hash_map>* req_frames, + absl::flat_hash_map>* resp_frames, mqtt::StateWrapper* state); + +} // namespace mqtt + +template <> +inline RecordsWithErrorCount StitchFrames( + absl::flat_hash_map>* req_messages, + absl::flat_hash_map>* res_messages, + mqtt::StateWrapper* state) { + return mqtt::StitchFrames(req_messages, res_messages, state); +} + +} // namespace protocols +} // namespace stirling +} // namespace px diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher_test.cc b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher_test.cc new file mode 100644 index 00000000000..9ea7c482875 --- /dev/null +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher_test.cc @@ -0,0 +1,469 @@ +/* + * Copyright 2018- The Pixie Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include + +#include "src/stirling/source_connectors/socket_tracer/protocols/common/test_utils.h" +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.h" +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/test_utils.h" + +namespace px { +namespace stirling { +namespace protocols { +namespace mqtt { + +using testutils::CreateFrame; + +TEST(MqttStitcherTest, EmptyInputs) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 0); +} + +TEST(MqttStitcherTest, OnlyRequests) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + Message connect_frame, pingreq_frame; + connect_frame = CreateFrame(kRequest, MqttControlPacketType::CONNECT, 0, 0); + pingreq_frame = CreateFrame(kRequest, MqttControlPacketType::CONNACK, 0, 0); + + int t = 0; + connect_frame.timestamp_ns = ++t; + pingreq_frame.timestamp_ns = ++t; + + req_map[0].push_back(connect_frame); + req_map[0].push_back(pingreq_frame); + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 2); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 0); +} + +TEST(MqttStitcherTest, OnlyResponses) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + Message connack_frame, pingresp_frame; + connack_frame = CreateFrame(kResponse, MqttControlPacketType::CONNACK, 0, 0); + pingresp_frame = CreateFrame(kResponse, MqttControlPacketType::PINGRESP, 0, 0); + + int t = 0; + connack_frame.timestamp_ns = ++t; + pingresp_frame.timestamp_ns = ++t; + + resp_map[0].push_back(connack_frame); + resp_map[0].push_back(pingresp_frame); + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(result.error_count, 2); + EXPECT_EQ(result.records.size(), 0); +} + +TEST(MqttStitcherTest, MissingResponseBeforeNextResponse) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + Message pub1_frame, puback_frame, sub_frame, suback_frame, unsub_frame, unsuback_frame; + pub1_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + puback_frame = CreateFrame(kResponse, MqttControlPacketType::PUBACK, 1, 0); + sub_frame = CreateFrame(kRequest, MqttControlPacketType::SUBSCRIBE, 0, 0); + unsub_frame = CreateFrame(kRequest, MqttControlPacketType::UNSUBSCRIBE, 0, 0); + unsuback_frame = CreateFrame(kResponse, MqttControlPacketType::UNSUBACK, 0, 0); + + int t = 0; + sub_frame.timestamp_ns = ++t; + pub1_frame.timestamp_ns = ++t; + puback_frame.timestamp_ns = ++t; + unsub_frame.timestamp_ns = ++t; + unsuback_frame.timestamp_ns = ++t; + + req_map[1].push_back(sub_frame); + req_map[1].push_back(pub1_frame); + req_map[2].push_back(unsub_frame); + + resp_map[1].push_back(puback_frame); + resp_map[2].push_back(unsuback_frame); + + // Update the state for PUBLISH packet ID 1 and QOS 1 + state.send[std::tuple(1, 1)] = 0; + + result = StitchFrames(&req_map, &resp_map, &state); + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 1); + EXPECT_EQ(result.records.size(), 2); +} + +TEST(MqttStitcherTest, MissingResponseTailEnd) { + // Response not yet received for a particular packet identifier + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + Message pub1_frame, sub_frame, suback_frame, unsub_frame, unsuback_frame; + pub1_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + sub_frame = CreateFrame(kRequest, MqttControlPacketType::SUBSCRIBE, 0, 0); + suback_frame = CreateFrame(kResponse, MqttControlPacketType::SUBACK, 0, 0); + unsub_frame = CreateFrame(kRequest, MqttControlPacketType::UNSUBSCRIBE, 0, 0); + unsuback_frame = CreateFrame(kResponse, MqttControlPacketType::UNSUBACK, 0, 0); + + int t = 0; + sub_frame.timestamp_ns = ++t; + pub1_frame.timestamp_ns = ++t; + unsub_frame.timestamp_ns = ++t; + suback_frame.timestamp_ns = ++t; + unsuback_frame.timestamp_ns = ++t; + + req_map[1].push_back(sub_frame); + req_map[1].push_back(pub1_frame); + req_map[2].push_back(unsub_frame); + + resp_map[1].push_back(suback_frame); + resp_map[2].push_back(unsuback_frame); + + // Update the state for PUBLISH packet ID 1 and QOS 1 + state.send[std::tuple(1, 1)] = 0; + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 1); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 2); + req_map.clear(); +} + +TEST(MqttStitcherTest, MissingRequest) { + // Test for packet stitching for packets that do not have packet identifiers + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + Message connect_frame, connack_frame, pingresp_frame; + connect_frame = CreateFrame(kRequest, MqttControlPacketType::CONNECT, 0, 0); + connack_frame = CreateFrame(kResponse, MqttControlPacketType::CONNACK, 0, 0); + pingresp_frame = CreateFrame(kResponse, MqttControlPacketType::PINGRESP, 0, 0); + + int t = 0; + connect_frame.timestamp_ns = ++t; + connack_frame.timestamp_ns = ++t; + pingresp_frame.timestamp_ns = ++t; + + req_map[0].push_back(connect_frame); + + resp_map[0].push_back(connack_frame); + resp_map[0].push_back(pingresp_frame); + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 1); + EXPECT_EQ(result.records.size(), 1); +} + +TEST(MqttStitcherTest, InOrderMatching) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + // Establishment of connection, ping requests and responses, and three publish requests (qos 1, + // qos 2 and qos 1) with increasing packet identifiers (since they are sent before their responses + // are received) + Message connect_frame, connack_frame, pingreq_frame, pingresp_frame, pub1_frame, pub1_pid3_frame, + puback_frame, puback_pid3_frame, pub2_pid2_frame, pubrec_pid2_frame, pubrel_pid2_frame, + pubcomp_pid2_frame, sub_frame, suback_frame, unsub_frame, unsuback_frame, auth_frame_server, + auth_frame_client; + connect_frame = CreateFrame(kRequest, MqttControlPacketType::CONNECT, 0, 0); + connack_frame = CreateFrame(kResponse, MqttControlPacketType::CONNACK, 0, 0); + pingreq_frame = CreateFrame(kRequest, MqttControlPacketType::PINGREQ, 0, 0); + pingresp_frame = CreateFrame(kResponse, MqttControlPacketType::PINGRESP, 0, 0); + pub1_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + pub1_pid3_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 3, 1); + pub2_pid2_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 2, 2); + puback_frame = CreateFrame(kResponse, MqttControlPacketType::PUBACK, 1, 0); + puback_pid3_frame = CreateFrame(kResponse, MqttControlPacketType::PUBACK, 3, 0); + pubrec_pid2_frame = CreateFrame(kResponse, MqttControlPacketType::PUBREC, 2, 0); + pubrel_pid2_frame = CreateFrame(kRequest, MqttControlPacketType::PUBREL, 2, 0); + pubcomp_pid2_frame = CreateFrame(kResponse, MqttControlPacketType::PUBCOMP, 2, 0); + sub_frame = CreateFrame(kRequest, MqttControlPacketType::SUBSCRIBE, 0, 0); + suback_frame = CreateFrame(kResponse, MqttControlPacketType::SUBACK, 0, 0); + unsub_frame = CreateFrame(kRequest, MqttControlPacketType::UNSUBSCRIBE, 0, 0); + unsuback_frame = CreateFrame(kResponse, MqttControlPacketType::UNSUBACK, 0, 0); + + int t = 0; + connect_frame.timestamp_ns = ++t; + connack_frame.timestamp_ns = ++t; + pingreq_frame.timestamp_ns = ++t; + pingresp_frame.timestamp_ns = ++t; + sub_frame.timestamp_ns = ++t; + suback_frame.timestamp_ns = ++t; + pub1_frame.timestamp_ns = ++t; + pub2_pid2_frame.timestamp_ns = ++t; + pub1_pid3_frame.timestamp_ns = ++t; + puback_frame.timestamp_ns = ++t; + pubrec_pid2_frame.timestamp_ns = ++t; + puback_pid3_frame.timestamp_ns = ++t; + pubrel_pid2_frame.timestamp_ns = ++t; + unsub_frame.timestamp_ns = ++t; + pubcomp_pid2_frame.timestamp_ns = ++t; + unsuback_frame.timestamp_ns = ++t; + auth_frame_server.timestamp_ns = ++t; + auth_frame_client.timestamp_ns = ++t; + + req_map[1].push_back(connect_frame); + req_map[1].push_back(pingreq_frame); + req_map[1].push_back(sub_frame); + req_map[1].push_back(pub1_frame); + req_map[2].push_back(pub2_pid2_frame); + req_map[3].push_back(pub1_pid3_frame); + req_map[2].push_back(pubrel_pid2_frame); + req_map[2].push_back(unsub_frame); + + resp_map[1].push_back(connack_frame); + resp_map[1].push_back(pingresp_frame); + resp_map[1].push_back(suback_frame); + resp_map[1].push_back(puback_frame); + resp_map[2].push_back(pubrec_pid2_frame); + resp_map[3].push_back(puback_pid3_frame); + resp_map[2].push_back(pubcomp_pid2_frame); + resp_map[2].push_back(unsuback_frame); + + // Update the state for PUBLISH packet ID 1 and QOS 1, packet ID 2 and QOS 2, packet ID 3 and QOS + // 1 + state.send[std::tuple(1, 1)] = 0; + state.send[std::tuple(2, 2)] = 0; + state.send[std::tuple(2, 1)] = 0; + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 8); +} + +TEST(MqttStitcherTest, OutOfOrderMatching) { + // Test for packet stitching for packets that do not have packet identifiers + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + // Delayed response for PUBLISH (PID 3, QOS 1) and SUBSCRIBE (PID 4) (Delayed meaning some the + // response for this request comes after the responses for later requests) + Message pub1_pid_1_frame, pub1_pid3_frame, puback_pid1_frame, puback_pid3_frame, pub2_pid2_frame, + pubrec_pid2_frame, pubrel_pid2_frame, pubcomp_pid2_frame, sub_pid4_frame, suback_pid4_frame; + + pub1_pid_1_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + pub1_pid3_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 3, 1); + pub2_pid2_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 2, 2); + puback_pid1_frame = CreateFrame(kResponse, MqttControlPacketType::PUBACK, 1, 0); + puback_pid3_frame = CreateFrame(kResponse, MqttControlPacketType::PUBACK, 3, 0); + pubrec_pid2_frame = CreateFrame(kResponse, MqttControlPacketType::PUBREC, 2, 0); + pubrel_pid2_frame = CreateFrame(kRequest, MqttControlPacketType::PUBREL, 2, 0); + pubcomp_pid2_frame = CreateFrame(kResponse, MqttControlPacketType::PUBCOMP, 2, 0); + sub_pid4_frame = CreateFrame(kRequest, MqttControlPacketType::SUBSCRIBE, 4, 0); + suback_pid4_frame = CreateFrame(kResponse, MqttControlPacketType::SUBACK, 4, 0); + + // Delayed responses with interleaved requests + int t = 0; + pub1_pid_1_frame.timestamp_ns = ++t; + pub2_pid2_frame.timestamp_ns = ++t; + pub1_pid3_frame.timestamp_ns = ++t; + sub_pid4_frame.timestamp_ns = ++t; + puback_pid1_frame.timestamp_ns = ++t; + pubrec_pid2_frame.timestamp_ns = ++t; + pubrel_pid2_frame.timestamp_ns = ++t; + pubcomp_pid2_frame.timestamp_ns = ++t; + puback_pid3_frame.timestamp_ns = ++t; + suback_pid4_frame.timestamp_ns = ++t; + + req_map[1].push_back(pub1_pid_1_frame); + req_map[2].push_back(pub2_pid2_frame); + req_map[3].push_back(pub1_pid3_frame); + req_map[2].push_back(pubrel_pid2_frame); + resp_map[1].push_back(puback_pid1_frame); + resp_map[2].push_back(pubrec_pid2_frame); + resp_map[2].push_back(pubcomp_pid2_frame); + resp_map[3].push_back(puback_pid3_frame); + + // Update the state for PUBLISH packet ID 1 and QOS 1, packet ID 2 and QOS 2, packet ID 3 and QOS + // 1 + state.send[std::tuple(1, 1)] = 0; + state.send[std::tuple(2, 2)] = 0; + state.send[std::tuple(3, 1)] = 0; + + result = StitchFrames(&req_map, &resp_map, &state); + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 4); +} + +TEST(MqttStitcherTest, DummyResponseStitching) { + // Test for requests that do not have responses + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state = {.global = {}, .send = {}, .recv = {}}; + + Message pub0_frame, disconnect_frame, connect_frame, connack_frame; + pub0_frame = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 0, 0); // PUBLISH with QoS 0 + disconnect_frame = CreateFrame(kRequest, MqttControlPacketType::DISCONNECT, 0, 0); // DISCONNECT + + int t = 0; + pub0_frame.timestamp_ns = ++t; + disconnect_frame.timestamp_ns = ++t; + + req_map[0].push_back(pub0_frame); + req_map[0].push_back(disconnect_frame); + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 2); +} + +TEST(MqttStitcherTest, DuplicateAnsweredRequests) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state; + state.send = {}; + state.recv = {}; + + Message pub1_frame_1, pub1_frame_2, pub2_frame_1, pub2_frame_2, puback_frame, pubrec_frame, + subscribe_frame, suback_frame; + pub1_frame_1 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + pub1_frame_2 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + pub1_frame_2.dup = true; + pub2_frame_1 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 2); + pub2_frame_2 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 2); + pub2_frame_2.dup = true; + puback_frame = CreateFrame(kResponse, MqttControlPacketType::PUBACK, 1, 0); + pubrec_frame = CreateFrame(kResponse, MqttControlPacketType::PUBREC, 1, 0); + + int t = 0; + pub1_frame_1.timestamp_ns = ++t; + pub1_frame_2.timestamp_ns = ++t; + puback_frame.timestamp_ns = ++t; + pub2_frame_1.timestamp_ns = ++t; + pub2_frame_2.timestamp_ns = ++t; + pubrec_frame.timestamp_ns = ++t; + + req_map[1].push_back(pub1_frame_1); + req_map[1].push_back(pub1_frame_2); + req_map[1].push_back(pub2_frame_1); + req_map[1].push_back(pub2_frame_2); + + resp_map[1].push_back(puback_frame); + resp_map[1].push_back(pubrec_frame); + + // Update the state for PUBLISH packet ID 1 and QOS 1, packet ID 2 and QOS 2, packet ID 1 and QOS + // 2 + state.send[std::tuple(1, 1)] = 1; + state.send[std::tuple(1, 2)] = 1; + + result = StitchFrames(&req_map, &resp_map, &state); + EXPECT_EQ(TotalDequeSize(req_map), 0); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 2); +} + +TEST(MqttStitcherTest, DuplicateUnansweredRequests) { + absl::flat_hash_map> req_map; + absl::flat_hash_map> resp_map; + + RecordsWithErrorCount result; + StateWrapper state; + state.send = {}; + state.recv = {}; + + Message pub1_frame_1, pub1_frame_2, pub2_frame_1, pub2_frame_2, puback_frame, pubrec_frame, + subscribe_frame, suback_frame; + pub1_frame_1 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + pub1_frame_2 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 1, 1); + pub1_frame_2.dup = true; + pub2_frame_1 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 2, 2); + pub2_frame_2 = CreateFrame(kRequest, MqttControlPacketType::PUBLISH, 2, 2); + pub2_frame_2.dup = true; + + // Unanswered duplicate PUBLISH (QOS 1) + int t = 0; + pub1_frame_1.timestamp_ns = ++t; + pub1_frame_2.timestamp_ns = ++t; + pub2_frame_1.timestamp_ns = ++t; + pub2_frame_2.timestamp_ns = ++t; + + req_map[1].push_back(pub1_frame_1); + req_map[2].push_back(pub1_frame_2); + req_map[2].push_back(pub2_frame_1); + req_map[2].push_back(pub2_frame_2); + + state.send[std::tuple(1, 1)] = 1; + state.send[std::tuple(2, 2)] = 1; + + result = StitchFrames(&req_map, &resp_map, &state); + + EXPECT_TRUE(AreAllDequesEmpty(resp_map)); + EXPECT_EQ(TotalDequeSize(req_map), 4); + EXPECT_EQ(result.error_count, 0); + EXPECT_EQ(result.records.size(), 0); +} + +} // namespace mqtt +} // namespace protocols +} // namespace stirling +} // namespace px diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/test_utils.h b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/test_utils.h new file mode 100644 index 00000000000..34d19b997fd --- /dev/null +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/test_utils.h @@ -0,0 +1,50 @@ +/* + * Copyright 2018- The Pixie Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +#include "src/common/base/utils.h" +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h" + +namespace px { +namespace stirling { +namespace protocols { +namespace mqtt { +namespace testutils { + +inline Message CreateFrame(message_type_t type, const MqttControlPacketType control_packet_type, + uint32_t packet_identifier, uint32_t qos) { + Message f; + + f.type = type; + f.control_packet_type = magic_enum::enum_integer(control_packet_type); + f.header_fields["packet_identifier"] = packet_identifier; + f.header_fields["qos"] = qos; + + return f; +} + +} // namespace testutils +} // namespace mqtt +} // namespace protocols +} // namespace stirling +} // namespace px diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h index 912d2ba1d22..69f02c4d8d3 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h @@ -20,6 +20,7 @@ #include #include +#include #include "src/common/base/utils.h" #include "src/common/json/json.h" @@ -35,13 +36,35 @@ using ::px::utils::ToJSONString; // The protocol specification : https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.pdf // This supports MQTTv5 +// This is modeling a 4 bit field specifying the control packet type +enum class MqttControlPacketType : uint8_t { + CONNECT = 1, + CONNACK = 2, + PUBLISH = 3, + PUBACK = 4, + PUBREC = 5, + PUBREL = 6, + PUBCOMP = 7, + SUBSCRIBE = 8, + SUBACK = 9, + UNSUBSCRIBE = 10, + UNSUBACK = 11, + PINGREQ = 12, + PINGRESP = 13, + DISCONNECT = 14, + AUTH = 15 +}; + +using packet_id_t = uint16_t; struct Message : public FrameBase { message_type_t type = message_type_t::kUnknown; + // This is modeling a 4 bit field specifying the control packet type uint8_t control_packet_type = 0xff; - bool dup; - bool retain; + bool dup = false; + bool retain = false; + bool consumed = false; std::map header_fields; std::map properties, payload; @@ -73,13 +96,24 @@ struct Record { } }; +struct StateWrapper { + std::monostate global; + std::map, uint32_t> send; // Client side PUBLISHes + std::map, uint32_t> recv; // Server side PUBLISHes +}; + struct ProtocolTraits : public BaseProtocolTraits { using frame_type = Message; using record_type = Record; - using state_type = NoState; + using state_type = StateWrapper; + using key_type = packet_id_t; + static constexpr StreamSupport stream_support = BaseProtocolTraits::UseStream; }; } // namespace mqtt + +template <> +mqtt::packet_id_t GetStreamID(mqtt::Message* message); } // namespace protocols } // namespace stirling } // namespace px