Skip to content

Commit

Permalink
squashme: more fixes following review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ivmaykov committed Apr 24, 2023
1 parent d1c3725 commit 87967ed
Showing 1 changed file with 88 additions and 64 deletions.
152 changes: 88 additions & 64 deletions core/src/checks/rpc_checks.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "checks.h"
#include "config.h"
#include "log.h"
#include "memzero.h"
#include "nanopb_stream.h"
#include "print.h"
#include "rand.h"
Expand All @@ -25,7 +26,7 @@ static size_t get_serialized_proto_struct_size(const pb_field_t fields[], const
// Serializes the given protobuf structure to the given buffer of the given size, using pb_encode_delimited().
// Returns true on success or false on failure.
// Caller brings their own buffer memory, this function does not allocate.
static bool serialize_to_buffer(
static bool serialize_proto_struct_to_buffer(
pb_byte_t* const buffer,
const size_t buffer_size,
const pb_field_t fields[],
Expand All @@ -41,7 +42,7 @@ static bool serialize_to_buffer(
// Deserializes the given buffer into the given protobuf structure, using pb_decode_delimited().
// Returns true on success or false on failure.
// Caller brings their own protobuf structure, this function does not allocate.
static bool deserialize_from_buffer(
static bool deserialize_proto_struct_from_buffer(
const pb_byte_t* const buffer,
const size_t buffer_size,
const pb_field_t fields[],
Expand All @@ -54,10 +55,24 @@ static bool deserialize_from_buffer(
return true;
}

/**
* Helper for checking assumptions in the gnarly protobuf mangling code in verify_rpc_oversized_message_rejected().
*/
static bool check_byte_equals(const char* parent_func, const pb_byte_t* const buf, size_t idx, pb_byte_t expected_val) {
const pb_byte_t actual_val = buf[idx];
if (actual_val != expected_val) {
ERROR(
"%s: buf[%zu] contains an unexpected value: %hhu, expected: %hhu", parent_func, idx, actual_val, expected_val);
return false;
}
return true;
}

static pb_byte_t request_buffer[256] = { 0 };
static pb_byte_t response_buffer[256] = { 0 };

int verify_rpc_oversized_message_rejected(void) {
int result = 0;
pb_byte_t* serialized_request = NULL;
pb_byte_t* serialized_response = NULL;

// Construct an initial InternalCommandRequest which holds an InitWallet command
// with a maximum-allowed-length random_bytes field.
Expand All @@ -79,41 +94,23 @@ int verify_rpc_oversized_message_rejected(void) {
goto out;
}

// Allocate a buffer to hold the serialized struct.
// Note that we allocate 1 extra byte because we'll be extending the message.
serialized_request = (pb_byte_t*) calloc(1, serialized_size + 1);
if (NULL == serialized_request) {
ERROR("%s: calloc(1, %zu) failed", __func__, serialized_size + 1);
if (serialized_size + 1 > sizeof(request_buffer)) {
ERROR(
"%s: sizeof(request_buffer) == %zu but needs to be at least %zu. Modify the code and rebuild.",
__func__,
sizeof(request_buffer),
serialized_size + 1);
result = -1;
goto out;
}

// Serialize the struct to a byte array.
if (!serialize_to_buffer(serialized_request, serialized_size, InternalCommandRequest_fields, &cmd)) {
ERROR("%s: serialize_to_buf() failed", __func__);
// Serialize the request struct into request_buffer.
if (!serialize_proto_struct_to_buffer(request_buffer, sizeof(request_buffer), InternalCommandRequest_fields, &cmd)) {
ERROR("%s: serialize_proto_struct_to_buffer() failed", __func__);
result = -1;
goto out;
}

// Helper macro used to check our assumptions in the gnarly protobuf mangling code below
#define ASSERT_BYTE_EQUALS(buf, idx, expected_val) \
do { \
const pb_byte_t* buf_ = (buf); \
const size_t idx_ = (idx); \
const pb_byte_t expected_val_ = (expected_val); \
const pb_byte_t actual_val_ = buf_[idx_]; \
if (actual_val_ != expected_val_) { \
ERROR( \
"%s: buf[%zu] contains an unexpected value: %hhu, expected: %hhu", \
__func__, \
idx_, \
actual_val_, \
expected_val_); \
result = -1; \
goto out; \
} \
} while (0)

// Corrupt the message by making the random_bytes field 1 byte longer than the max allowed size.
// Note that this is a bit fragile and could break if the protobuf definitions inside
// internal.proto are changed. But if that happens, hopefully this test breaks immediately
Expand Down Expand Up @@ -149,41 +146,64 @@ int verify_rpc_oversized_message_rejected(void) {
// It can be any value, we arbitrarily choose 0xaa.
//
// Let's check the above assumptions to make sure they are correct before proceeding:
ASSERT_BYTE_EQUALS(serialized_request, 0, (pb_byte_t) 73);
ASSERT_BYTE_EQUALS(serialized_request, 0, (pb_byte_t) (serialized_size - 1));
ASSERT_BYTE_EQUALS(serialized_request, 1, (pb_byte_t) ((1 << 3) + 0));
if (!check_byte_equals(__func__, request_buffer, 0, (pb_byte_t) 73)) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 0, (pb_byte_t) (serialized_size - 1))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 1, (pb_byte_t) ((1 << 3) + 0))) {
result = -1;
goto out;
}
// The 'cmd.version' field is varint-encoded into 2 little-endian bytes:
// ... First byte contains least-significant 7 bits + highest bit set to indicate that there's more data.
ASSERT_BYTE_EQUALS(serialized_request, 2, (pb_byte_t) ((cmd.version & 0x7f) | 0x80));
if (!check_byte_equals(__func__, request_buffer, 2, (pb_byte_t) ((cmd.version & 0x7f) | 0x80))) {
result = -1;
goto out;
}
// ... Second byte contains the next 1-7 bits + highest bit unset to indicate that there's no more data.
ASSERT_BYTE_EQUALS(serialized_request, 3, (pb_byte_t) (cmd.version >> 7));
ASSERT_BYTE_EQUALS(serialized_request, 4, (pb_byte_t) ((2 << 3) + 0));
ASSERT_BYTE_EQUALS(serialized_request, 5, (pb_byte_t) cmd.wallet_id);
ASSERT_BYTE_EQUALS(serialized_request, 6, (pb_byte_t) ((5 << 3) + 2));
ASSERT_BYTE_EQUALS(serialized_request, 7, (pb_byte_t) (MASTER_SEED_SIZE + 2));
ASSERT_BYTE_EQUALS(serialized_request, 8, (pb_byte_t) ((1 << 3) + 2));
ASSERT_BYTE_EQUALS(serialized_request, 9, (pb_byte_t) MASTER_SEED_SIZE);
#undef ASSERT_BYTE_EQUALS

serialized_request[0]++; // increment leading LEN byte
serialized_request[7]++; // increment LEN byte for top-level field 5
serialized_request[9]++; // increment LEN byte for nested field 1
serialized_request[serialized_size] = 0xaa; // set the last byte to an arbitrary value
serialized_size++; // increment serialized_size since we added a byte of data

// Allocate a buffer for the serialized response.
const size_t response_buffer_size = 2048; // 2048 bytes should be more than enough
serialized_response = (pb_byte_t*) calloc(1, response_buffer_size);
if (NULL == serialized_response) {
ERROR("%s: calloc(1, %zu) failed", __func__, response_buffer_size);
if (!check_byte_equals(__func__, request_buffer, 3, (pb_byte_t) (cmd.version >> 7))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 4, (pb_byte_t) ((2 << 3) + 0))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 5, (pb_byte_t) cmd.wallet_id)) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 6, (pb_byte_t) ((5 << 3) + 2))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 7, (pb_byte_t) (MASTER_SEED_SIZE + 2))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 8, (pb_byte_t) ((1 << 3) + 2))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 9, (pb_byte_t) MASTER_SEED_SIZE)) {
result = -1;
goto out;
}

request_buffer[0]++; // increment leading LEN byte
request_buffer[7]++; // increment LEN byte for top-level field 5
request_buffer[9]++; // increment LEN byte for nested field 1
request_buffer[serialized_size] = 0xaa; // set the last byte to an arbitrary value
serialized_size++; // increment serialized_size since we added a byte of data

// Create a stream which will read from the corrupted serialized buffer.
pb_istream_t istream = pb_istream_from_buffer(serialized_request, serialized_size);
pb_istream_t istream = pb_istream_from_buffer(request_buffer, serialized_size);
// Create a stream which will write to the response buffer.
pb_ostream_t ostream = pb_ostream_from_buffer(serialized_response, response_buffer_size);
pb_ostream_t ostream = pb_ostream_from_buffer(response_buffer, sizeof(response_buffer));

// Now that we have a serialized buffer, try to pass it to handle_incoming_message().
// This should fail because the InitWallet.random_bytes field has a length of 65 bytes,
Expand All @@ -197,16 +217,19 @@ int verify_rpc_oversized_message_rejected(void) {
handle_incoming_message(&istream, &ostream); // <---- this is the actual function under test

// Extract the response structure from the serialized_response buffer. It should be an error.
const size_t actual_response_size = ostream.bytes_written;
if (actual_response_size == 0) {
const size_t response_size = ostream.bytes_written;
if (response_size == 0) {
ERROR("%s: no response received from handle_incoming_message(): %s", __func__, PB_GET_ERROR(&ostream));
result = -1;
goto out;
}

InternalCommandResponse response; // note: no need to init, deserialize_from_buf() does it via pb_decode_delimited().
if (!deserialize_from_buffer(serialized_response, actual_response_size, InternalCommandResponse_fields, &response)) {
ERROR("%s: deserialize_from_buf() failed", __func__);
// note: no need to initialize the response, static bool deserialize_proto_struct_from_buffer() does it via
// pb_decode_delimited().
InternalCommandResponse response;
if (!deserialize_proto_struct_from_buffer(
response_buffer, response_size, InternalCommandResponse_fields, &response)) {
ERROR("%s: deserialize_proto_struct_from_buffer() failed", __func__);
result = -1;
goto out;
}
Expand Down Expand Up @@ -248,8 +271,9 @@ int verify_rpc_oversized_message_rejected(void) {
}

out:
free(serialized_request);
free(serialized_response);
memzero(request_buffer, sizeof(request_buffer));
memzero(response_buffer, sizeof(response_buffer));

if (result == 0) {
INFO("%s: ok", __func__);
}
Expand Down

0 comments on commit 87967ed

Please sign in to comment.