Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable support for MQTT stitcher in stirling #1918

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ pl_cc_test(
":cc_library",
],
)

pl_cc_test(
name = "stitcher_test",
srcs = ["stitcher_test.cc"],
deps = [
":cc_library",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <iomanip>
#include <map>
#include <string_view>
#include <tuple>

#include "src/common/base/base.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h"
Expand All @@ -37,24 +38,6 @@ namespace protocols {

namespace mqtt {

enum class MqttControlPacketType : uint8_t {
CONNECT = 1,
CONNACK = 2,
PUBLISH = 3,
PUBACK = 4,
PUBREC = 5,
PUBREL = 6,
PUBCOMP = 7,
SUBSCRIBE = 8,
SUBACK = 9,
UNSUBSCRIBE = 10,
UNSUBACK = 11,
PINGREQ = 12,
PINGRESP = 13,
DISCONNECT = 14,
AUTH = 15
};

enum class PropertyCode : uint8_t {
PayloadFormatIndicator = 0x01,
MessageExpiryInterval = 0x02,
Expand Down Expand Up @@ -653,7 +636,8 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder,
}
}

ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* result) {
ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* result,
mqtt::StateWrapper* state) {
CTX_DCHECK(type == message_type_t::kRequest || type == message_type_t::kResponse);
if (buf->size() < 2) {
return ParseState::kNeedsMoreData;
Expand Down Expand Up @@ -723,6 +707,19 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* resul
return ParseState::kInvalid;
}

// Updating the state for PUBLISH based on whether it is duplicate
if (control_packet_type == MqttControlPacketType::PUBLISH) {
if (result->dup) {
if (type == message_type_t::kRequest) {
state->send[std::tuple<uint32_t, uint32_t>(result->header_fields["packet_identifier"],
result->header_fields["qos"])] += 1;
} else {
state->recv[std::tuple<uint32_t, uint32_t>(result->header_fields["packet_identifier"],
result->header_fields["qos"])] += 1;
}
}
}

if (ParsePayload(result, &decoder, control_packet_type) == ParseState::kInvalid) {
return ParseState::kInvalid;
}
Expand All @@ -735,8 +732,8 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, Message* resul

template <>
ParseState ParseFrame(message_type_t type, std::string_view* buf, mqtt::Message* result,
NoState* /*state*/) {
return mqtt::ParseFrame(type, buf, result);
mqtt::StateWrapper* state) {
return mqtt::ParseFrame(type, buf, result, state);
}

template <>
Expand All @@ -745,6 +742,11 @@ size_t FindFrameBoundary<mqtt::Message>(message_type_t /*type*/, std::string_vie
return start_pos + buf.length();
}

template <>
mqtt::packet_id_t GetStreamID(mqtt::Message* message) {
return message->header_fields["packet_identifier"];
}

} // namespace protocols
} // namespace stirling
} // namespace px
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@ namespace protocols {

template <>
ParseState ParseFrame(message_type_t type, std::string_view* buf, mqtt::Message* frame,
NoState* state);
mqtt::StateWrapper* state);

template <>
size_t FindFrameBoundary<mqtt::Message>(message_type_t type, std::string_view buf, size_t start_pos,
NoState* state);
mqtt::StateWrapper* state);

template <>
mqtt::packet_id_t GetStreamID(mqtt::Message* message);

} // namespace protocols
} // namespace stirling
Expand Down
Loading
Loading