Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a self-check test for processing a message which exceeds nanopb field size limits #589

Merged
merged 1 commit into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ set(checks_SRC
"src/checks/check_sign_tx.c"
"src/checks/conv_checks.c"
"src/checks/misc_checks.c"
"src/checks/rpc_checks.c"
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
"src/checks/self_checks.c"
"src/checks/validate_fees.c"
"src/checks/verify_mix_entropy.c"
Expand Down
9 changes: 9 additions & 0 deletions core/include/checks.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ int verify_no_rollback(void);
int verify_check_qrsignature_pub(void);
int verify_conv_btc_to_satoshi(void);

/**
* Verifies that calling handle_incoming_message() with a serialized protobuf
* message which exceeds nanopb field size limits (defined in proto .options
* files) fails as expected.
*
* Note that as this function uses statically-allocated buffers, it is not thread-safe.
*/
int verify_rpc_oversized_message_rejected(void);
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved

#define ASSERT_STR_EQUAL(value, expecting, message) \
do { \
if (strcmp(expecting, value) != 0) { \
Expand Down
281 changes: 281 additions & 0 deletions core/src/checks/rpc_checks.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
#include "checks.h"
#include "config.h"
#include "log.h"
#include "memzero.h"
#include "nanopb_stream.h"
#include "print.h"
#include "rand.h"
#include "rpc.h"

#include <assert.h>
#include <pb_decode.h>
#include <pb_encode.h>
#include <squareup/subzero/internal.pb.h>

ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
// Helper which returns the size of a buffer that would be needed to hold the serialized version of the given
// protobuf structure, assuming that pb_encode_delimited() serialization will be used.
static size_t get_serialized_proto_struct_size(const pb_field_t fields[], const void* const proto_struct) {
pb_ostream_t stream = PB_OSTREAM_SIZING;
if (!pb_encode_delimited(&stream, fields, proto_struct)) {
ERROR("%s: pb_encode_delimited() failed: %s", __func__, PB_GET_ERROR(&stream));
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
return 0;
}
return stream.bytes_written;
}

ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
// Serializes the given protobuf structure to the given buffer of the given size, using pb_encode_delimited().
// Returns true on success or false on failure.
// Caller brings their own buffer memory, this function does not allocate.
static bool serialize_proto_struct_to_buffer(
pb_byte_t* const buffer,
const size_t buffer_size,
const pb_field_t fields[],
const void* const proto_struct) {
pb_ostream_t ostream = pb_ostream_from_buffer(buffer, buffer_size);
if (!pb_encode_delimited(&ostream, fields, proto_struct)) {
ERROR("%s: pb_encode_delimited() failed: %s", __func__, PB_GET_ERROR(&ostream));
return false;
}
return true;
}

// Deserializes the given buffer into the given protobuf structure, using pb_decode_delimited().
// Returns true on success or false on failure.
// Caller brings their own protobuf structure, this function does not allocate.
static bool deserialize_proto_struct_from_buffer(
const pb_byte_t* const buffer,
const size_t buffer_size,
const pb_field_t fields[],
void* const proto_struct) {
pb_istream_t istream = pb_istream_from_buffer(buffer, buffer_size);
if (!pb_decode_delimited(&istream, fields, proto_struct)) {
ERROR("%s: pb_decode_delimited() failed: %s", __func__, PB_GET_ERROR(&istream));
return false;
}
return true;
}

/**
* Helper for checking assumptions in the gnarly protobuf mangling code in verify_rpc_oversized_message_rejected().
*/
static bool check_byte_equals(const char* parent_func, const pb_byte_t* const buf, size_t idx, pb_byte_t expected_val) {
const pb_byte_t actual_val = buf[idx];
if (actual_val != expected_val) {
ERROR(
"%s: buf[%zu] contains an unexpected value: %hhu, expected: %hhu", parent_func, idx, actual_val, expected_val);
return false;
}
return true;
}

static pb_byte_t request_buffer[256] = { 0 };
static pb_byte_t response_buffer[256] = { 0 };

int verify_rpc_oversized_message_rejected(void) {
int result = 0;

// Construct an initial InternalCommandRequest which holds an InitWallet command
// with a maximum-allowed-length random_bytes field.
InternalCommandRequest cmd = InternalCommandRequest_init_default;
cmd.version = VERSION;
cmd.wallet_id = 1; // dummy value
cmd.which_command = InternalCommandRequest_InitWallet_tag;
static_assert(
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
sizeof(cmd.command.InitWallet.random_bytes.bytes) == MASTER_SEED_SIZE,
"MASTER_SEED_SIZE must equal sizeof(cmd.command.InitWallet.random_bytes.bytes)");
cmd.command.InitWallet.random_bytes.size = MASTER_SEED_SIZE;
random_buffer(cmd.command.InitWallet.random_bytes.bytes, MASTER_SEED_SIZE);

// Compute the size of the serialized struct.
size_t serialized_size = get_serialized_proto_struct_size(InternalCommandRequest_fields, &cmd);
if (serialized_size == 0) {
ERROR("%s: error computing serialized request size", __func__);
result = -1;
goto out;
}

if (serialized_size + 1 > sizeof(request_buffer)) {
ERROR(
"%s: sizeof(request_buffer) == %zu but needs to be at least %zu. Modify the code and rebuild.",
__func__,
sizeof(request_buffer),
serialized_size + 1);
result = -1;
goto out;
}

// Serialize the request struct into request_buffer.
if (!serialize_proto_struct_to_buffer(request_buffer, sizeof(request_buffer), InternalCommandRequest_fields, &cmd)) {
ERROR("%s: serialize_proto_struct_to_buffer() failed", __func__);
result = -1;
goto out;
}

// Corrupt the message by making the random_bytes field 1 byte longer than the max allowed size.
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
// Note that this is a bit fragile and could break if the protobuf definitions inside
// internal.proto are changed. But if that happens, hopefully this test breaks immediately
// and can be fixed. Understanding of low-level protobuf serialization is recommended, see
// https://protobuf.dev/programming-guides/encoding/ for the details (it's not that bad).
// Basically:
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
// serialized_request[0] - varint-encoded leading LEN byte. This is not normally there for binary
// encoded protobufs, but it's added by nanopb because we are using
// pb_encode_delimited(). If the message is longer than 127 bytes, this
// length will actually take more than 1 byte, shifting everything after
// it by a byte.
// *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
// serialized_request[1] - field id (1 << 3) + tag (0) for field 1 (version). Should equal 8.
// serialized_request[2..3] - varint-encoded value for field 1. Leave this alone, it's the
// contents of the 'version' field (210 at the time of writing). If
// version ever exceeds 16383, this will start taking up an extra byte
// and shift everything after it by a byte.
// serialized_request[4] - field id (2 << 3) + tag (0) for field 2 (wallet_id). Should equal 16.
// serialized_request[5] - varint-encoded value for field 2. Leave this alone, it's the dummy
// 'wallet' field which we set to 1 above. Should equal 1.
// serialized_request[6] - field id (5 << 3) + tag (2, for 'LEN') for field 5 (command.InitWallet).
// Should equal 42.
// serialized_request[7] - varint-encoded LEN of the InitWalletRequest submessage.
// Should equal 66.
// *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
// serialized_request[8] - field id (1 << 3) + tag (2, for 'LEN') for field 1 of sub-message.
// Should equal 10.
// serialized_request[9] - varint-encoded LEN of field 1 (random_bytes) of sub-message.
// Should equal 64.
// *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
// serialized_request[10..73] - the contents of the random_bytes field. Should be 64 bytes in length.
// serialized_request[74] - doesn't exist in the original message. We add an extra data byte here.
// It can be any value, we arbitrarily choose 0xaa.
//
// Let's check the above assumptions to make sure they are correct before proceeding:
if (!check_byte_equals(__func__, request_buffer, 0, (pb_byte_t) 73)) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 0, (pb_byte_t) (serialized_size - 1))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 1, (pb_byte_t) ((1 << 3) + 0))) {
result = -1;
goto out;
}
// The 'cmd.version' field is varint-encoded into 2 little-endian bytes:
// ... First byte contains least-significant 7 bits + highest bit set to indicate that there's more data.
if (!check_byte_equals(__func__, request_buffer, 2, (pb_byte_t) ((cmd.version & 0x7f) | 0x80))) {
result = -1;
goto out;
}
// ... Second byte contains the next 1-7 bits + highest bit unset to indicate that there's no more data.
if (!check_byte_equals(__func__, request_buffer, 3, (pb_byte_t) (cmd.version >> 7))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 4, (pb_byte_t) ((2 << 3) + 0))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 5, (pb_byte_t) cmd.wallet_id)) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 6, (pb_byte_t) ((5 << 3) + 2))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 7, (pb_byte_t) (MASTER_SEED_SIZE + 2))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 8, (pb_byte_t) ((1 << 3) + 2))) {
result = -1;
goto out;
}
if (!check_byte_equals(__func__, request_buffer, 9, (pb_byte_t) MASTER_SEED_SIZE)) {
result = -1;
goto out;
}

request_buffer[0]++; // increment leading LEN byte
request_buffer[7]++; // increment LEN byte for top-level field 5
request_buffer[9]++; // increment LEN byte for nested field 1
request_buffer[serialized_size] = 0xaa; // set the last byte to an arbitrary value
serialized_size++; // increment serialized_size since we added a byte of data

// Create a stream which will read from the corrupted serialized buffer.
pb_istream_t istream = pb_istream_from_buffer(request_buffer, serialized_size);
// Create a stream which will write to the response buffer.
pb_ostream_t ostream = pb_ostream_from_buffer(response_buffer, sizeof(response_buffer));

// Now that we have a serialized buffer, try to pass it to handle_incoming_message().
// This should fail because the InitWallet.random_bytes field has a length of 65 bytes,
// but nanopb options set the limit for this field at 64 bytes.
//
// NOTE: when building for nCipher, there are command hooks that would reject the command
// because it's missing the tickets for key use authorization. But this doesn't matter for
// this test case, because the protobuf parsing happens before that and fails first.
ERROR("(next line is expected to show red text...)");

handle_incoming_message(&istream, &ostream); // <---- this is the actual function under test

// Extract the response structure from the serialized_response buffer. It should be an error.
const size_t response_size = ostream.bytes_written;
if (response_size == 0) {
ERROR("%s: no response received from handle_incoming_message(): %s", __func__, PB_GET_ERROR(&ostream));
result = -1;
goto out;
}

// note: no need to initialize the response, static bool deserialize_proto_struct_from_buffer() does it via
// pb_decode_delimited().
InternalCommandResponse response;
if (!deserialize_proto_struct_from_buffer(
response_buffer, response_size, InternalCommandResponse_fields, &response)) {
ERROR("%s: deserialize_proto_struct_from_buffer() failed", __func__);
result = -1;
goto out;
}

// Check that the response contains an error.
if (response.which_response != InternalCommandResponse_Error_tag) {
ERROR(
"%s: wrong response tag: %d, expected: %d",
__func__,
(int) response.which_response,
(int) InternalCommandResponse_Error_tag);
result = -1;
goto out;
}

// Check that the error response contains the expected error code.
if (response.response.Error.code != Result_COMMAND_DECODE_FAILED) {
ERROR(
"%s: wrong response error code: %d, expected: %d",
__func__,
(int) response.response.Error.code,
(int) Result_COMMAND_DECODE_FAILED);
result = -1;
goto out;
}

// Check that the error response contains some message.
if (!response.response.Error.has_message) {
ERROR("%s: error response does not contain a 'message' field", __func__);
result = -1;
goto out;
}

// Check that the error response contains the expected message.
if (0 != strcmp("Decode Input failed: bytes overflow", response.response.Error.message)) {
ERROR("%s: error response contains unexpected message: %s", __func__, response.response.Error.message);
result = -1;
goto out;
}

out:
memzero(request_buffer, sizeof(request_buffer));
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
memzero(response_buffer, sizeof(response_buffer));

if (result == 0) {
INFO("%s: ok", __func__);
}
return result;
}
6 changes: 6 additions & 0 deletions core/src/checks/self_checks.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ int run_self_checks(void) {
ERROR("self check failure: verify_conv_btc_to_satoshi failed.");
}

t = verify_rpc_oversized_message_rejected();
ivmaykov marked this conversation as resolved.
Show resolved Hide resolved
if (t != 0) {
r = -1;
ERROR("self check failure: verify_rpc_oversized_message_rejected failed.");
}

// environment specific additional checks + cleanup
t = post_run_self_checks();
if (t != 0) {
Expand Down