diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8098c43..15323eb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,12 +8,40 @@ jobs: strategy: fail-fast: false matrix: - platform: [ubuntu-latest, macos-latest] - runs-on: ${{ matrix.platform }} + platform: + - runner: ubuntu-latest + sonarcloud-build-wrapper : build-wrapper-linux-x86-64 + - runner: ubuntu-20.04 + sonarcloud-build-wrapper : build-wrapper-linux-x86-64 + - runner: macos-latest + sonarcloud-build-wrapper : build-wrapper-macosx-x86 + runs-on: ${{ matrix.platform.runner }} steps: - uses: actions/checkout@v3 + + - name: Install gcovr + run: pip3 install gcovr + + - name: Setup sonar cloud + uses: SonarSource/sonarcloud-github-c-cpp@v2.0.2 + + - name: Build tests + run: cmake -B build -S . + - name: Build tests - run: mkdir build && cd build && cmake .. && make tests - - name: Run tests - run: ./tests/tests - working-directory: build + run: ${{ matrix.platform.sonarcloud-build-wrapper }} --out-dir bw-output cmake --build build + + - name: Run tests and coverage + run: ./build/tests/tests + + - name: Generate coverage report + run: gcovr --sonarqube -o coverage.xml build + + - name: Run sonar-scanner + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SONAR_TOKEN: ${{ secrets.AUTERION_CI_SONAR_TOKEN }} + run: | + sonar-scanner \ + --define sonar.cfamily.build-wrapper-output="bw-output" \ + --define sonar.coverageReportPaths=coverage.xml diff --git a/.gitignore b/.gitignore index ea3a169..731fece 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ cmake-build-debug build Testing +html diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json new file mode 100644 index 0000000..6006e92 --- /dev/null +++ b/.vscode/c_cpp_properties.json @@ -0,0 +1,17 @@ +{ + "configurations": [ + { + "name": "Linux", + "includePath": [ + "${workspaceFolder}/**" + ], + "defines": [], + "compilerPath": "/usr/bin/gcc", + "cStandard": "c11", + "cppStandard": "c++17", + "intelliSenseMode": "gcc-x64", + "compileCommands": "${workspaceFolder}/build/compile_commands.json" + } + ], + "version": 4 +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b2c3c1..9dcefef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,13 @@ include(GNUInstallDirs) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) +find_program(CCACHE_PROGRAM ccache) +if(CCACHE_PROGRAM) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}") +endif() + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + set(CMAKE_CXX_STANDARD 17) add_library(mav INTERFACE) diff --git a/README.md b/README.md index 35f1025..916495b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,22 @@ Since the library is header only, you only need the library on the build system. You can also include the library as a submodule in your project. +### Running the tests + +Libmav uses [doctest](https://github.com/doctest/doctest/) and [gcovr](https://github.com/gcovr/gcovr/). + +To run the tests, build the library, then run the test executable. Test results will be output to console. + +```bash +mkdir build && cd build && cmake .. && make tests +./tests/tests +``` + +To test coverage, simple invoke the coverage tool from the root directory. +```bash +gcovr +``` + ## Getting started ### Loading a message set diff --git a/gcovr.cfg b/gcovr.cfg index 1440025..4e242a4 100644 --- a/gcovr.cfg +++ b/gcovr.cfg @@ -1,3 +1,4 @@ exclude-throw-branches = yes filter = include/mav -exclude = include/mav/rapidxml/* \ No newline at end of file +exclude = include/mav/rapidxml/* +exclude = include/mav/picosha2/* \ No newline at end of file diff --git a/include/mav/Connection.h b/include/mav/Connection.h index e03c1f1..bc3af7e 100644 --- a/include/mav/Connection.h +++ b/include/mav/Connection.h @@ -63,9 +63,7 @@ namespace mav { struct PromiseCallback { Expectation promise; - int message_id; - int system_id; - int component_id; + std::function selector; }; using Callback = std::variant; @@ -123,9 +121,7 @@ namespace mav { } it++; } else if constexpr (std::is_same_v) { - if (message.id() == arg.message_id && - (arg.system_id == mav::ANY_ID || message.header().systemId() == arg.system_id) && - (arg.component_id == mav::ANY_ID || message.header().componentId() == arg.component_id)) { + if (arg.selector(message)) { arg.promise->set_value(message); it = _message_callbacks.erase(it); } else { @@ -199,20 +195,24 @@ namespace mav { _message_callbacks.erase(handle); } - - [[nodiscard]] Expectation expect(int message_id, int source_id=mav::ANY_ID, - int component_id=mav::ANY_ID) { - + [[nodiscard]] Expectation expect(std::function selector) { auto promise = std::make_shared>(); std::scoped_lock lock(_message_callback_mtx); CallbackHandle handle = _next_handle; - _message_callbacks[handle] = PromiseCallback{promise, message_id, source_id, component_id}; + _message_callbacks[handle] = PromiseCallback{promise, std::move(selector)}; _next_handle++; - - auto prom = std::make_shared>(); return promise; } + [[nodiscard]] Expectation expect(int message_id, int source_id=mav::ANY_ID, + int component_id=mav::ANY_ID) { + return expect([message_id, source_id, component_id](const Message &message) { + return message.id() == message_id && + (source_id == mav::ANY_ID || message.header().systemId() == source_id) && + (component_id == mav::ANY_ID || message.header().componentId() == component_id); + }); + } + [[nodiscard]] inline Expectation expect(const std::string &message_name, int source_id=mav::ANY_ID, int component_id=mav::ANY_ID) { return expect(_message_set.idForMessage(message_name), source_id, component_id); @@ -250,6 +250,10 @@ namespace mav { Message inline receive(int message_id, int timeout_ms=-1) { return receive(message_id, mav::ANY_ID, mav::ANY_ID, timeout_ms); } + + Message inline receive(std::function selector, int timeout_ms=-1) { + return receive(expect(std::move(selector)), timeout_ms); + } }; }; diff --git a/include/mav/Message.h b/include/mav/Message.h index 167853b..51625ab 100644 --- a/include/mav/Message.h +++ b/include/mav/Message.h @@ -77,10 +77,10 @@ namespace mav { class Message { friend MessageSet; private: - ConnectionPartner _source_partner; + ConnectionPartner _source_partner{}; std::array _backing_memory{}; - const MessageDefinition* _message_definition; - int _crc_offset = -1; + const MessageDefinition* _message_definition{nullptr}; + int _crc_offset{-1}; explicit Message(const MessageDefinition &message_definition) : _message_definition(&message_definition) { @@ -154,33 +154,27 @@ namespace mav { throw std::runtime_error("Unknown base type"); // should never happen } - uint64_t _computeSignatureHash48(const std::array& key) const { + uint64_t _computeSignatureHash48(const std::array& key) const { // signature = sha256_48(secret_key + header + payload + CRC + link-ID + timestamp) - constexpr size_t maxSize = 32 + MessageDefinition::HEADER_SIZE + - MessageDefinition::MAX_PAYLOAD_SIZE + - MessageDefinition::CHECKSUM_SIZE + 1 + 6; - std::array data; - size_t actualSize = 0; + picosha2::hash256_one_by_one hasher; // secret_key - std::copy_n(key.begin(), 32, data.begin() + actualSize); - actualSize += 32; + hasher.process(key.begin(), key.begin() + MessageDefinition::KEY_SIZE); // header + payload + CRC - const size_t dataSize = - MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE; - std::copy_n(_backing_memory.begin(), dataSize, data.begin() + actualSize); - actualSize += dataSize; + hasher.process(_backing_memory.begin(), _backing_memory.begin() + + MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE); // link-ID const uint8_t linkId = signature().linkId(); - serialize(linkId, data.begin() + actualSize); - actualSize += 1; + hasher.process(&linkId, &linkId + MessageDefinition::SIGNATURE_LINK_ID_SIZE); // timestamp const uint64_t timestamp = signature().timestamp(); - serialize(timestamp, data.begin() + actualSize); - actualSize += 6; + std::array timestampSerialized; + serialize(timestamp, timestampSerialized.begin()); + hasher.process(timestampSerialized.begin(), timestampSerialized.begin() + MessageDefinition::SIGNATURE_TIMESTAMP_SIZE); + hasher.finish(); std::vector hash(picosha2::k_digest_size); - picosha2::hash256(data.begin(), data.begin() + actualSize, hash.begin(), hash.end()); - return deserialize(hash.data(), 6); + hasher.get_hash_bytes(hash.begin(), hash.end()); + return deserialize(hash.data(), MessageDefinition::SIGNATURE_SIGNATURE_SIZE); } public: @@ -253,10 +247,16 @@ namespace mav { } [[nodiscard]] const Signature signature() const { + if (!isFinalized()) { + throw std::runtime_error("Unable to parse unfinalized message."); + } return Signature(&_backing_memory[MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE]); } [[nodiscard]] Signature signature() { + if (!isFinalized()) { + throw std::runtime_error("Unable to parse unfinalized message."); + } return Signature(&_backing_memory[MessageDefinition::HEADER_SIZE + header().len() + MessageDefinition::CHECKSUM_SIZE]); } @@ -478,21 +478,23 @@ namespace mav { return ss.str(); } - void sign(const std::array& key, const uint64_t& timestamp) { - signature().linkId() = 0; - signature().timestamp() = timestamp; - signature().signature() = _computeSignatureHash48(key); + [[nodiscard]] bool validate(const std::array& key) const { + return signature().signature() == _computeSignatureHash48(key); } - [[nodiscard]] bool validate(const std::array& key) const { - return signature().signature() == _computeSignatureHash48(key); + [[nodiscard]] uint32_t finalize(uint8_t seq, const Identifier &sender) { + static const std::array null_key = {}; + return finalize(seq, sender, null_key, 0, 0); } - [[nodiscard]] uint32_t finalize(uint8_t seq, const Identifier &sender, const bool sign = false) { + [[nodiscard]] uint32_t finalize(uint8_t seq, const Identifier &sender, + const std::array& key, + const uint64_t& timestamp, const uint8_t linkId = 0) { if (isFinalized()) { _unFinalize(); } + bool sign = (timestamp > 0); auto last_nonzero = std::find_if(_backing_memory.rend() - MessageDefinition::HEADER_SIZE - _message_definition->maxPayloadSize(), _backing_memory.rend(), [](const auto &v) { @@ -504,15 +506,15 @@ namespace mav { - MessageDefinition::HEADER_SIZE, 1); header().magic() = 0xFD; - header().len() = payload_size; + header().len() = static_cast(payload_size); header().incompatFlags() = sign ? 0x01 : 0x00; header().compatFlags() = 0; header().seq() = seq; if (header().systemId() == 0) { - header().systemId() = sender.system_id; + header().systemId() = static_cast(sender.system_id); } if (header().componentId() == 0) { - header().componentId() = sender.component_id; + header().componentId() = static_cast(sender.component_id); } header().msgId() = _message_definition->id(); @@ -523,7 +525,15 @@ namespace mav { _crc_offset = MessageDefinition::HEADER_SIZE + payload_size; serialize(crc.crc16(), _backing_memory.data() + _crc_offset); - return MessageDefinition::HEADER_SIZE + payload_size + MessageDefinition::CHECKSUM_SIZE; + int signature_size = 0; + if (sign) { + signature().linkId() = linkId; + signature().timestamp() = timestamp; + signature().signature() = _computeSignatureHash48(key); + signature_size = MessageDefinition::SIGNATURE_SIZE; + } + + return MessageDefinition::HEADER_SIZE + payload_size + MessageDefinition::CHECKSUM_SIZE + signature_size; } [[nodiscard]] const uint8_t* data() const { @@ -533,4 +543,4 @@ namespace mav { } // namespace libmavlink` -#endif //MAV_DYNAMICMESSAGE_H \ No newline at end of file +#endif //MAV_DYNAMICMESSAGE_H diff --git a/include/mav/MessageDefinition.h b/include/mav/MessageDefinition.h index f4595de..81396a8 100644 --- a/include/mav/MessageDefinition.h +++ b/include/mav/MessageDefinition.h @@ -367,8 +367,12 @@ namespace mav { static constexpr int MAX_PAYLOAD_SIZE = 255; static constexpr int HEADER_SIZE = 10; static constexpr int CHECKSUM_SIZE = 2; - static constexpr int SIGNATURE_SIZE = 13; + static constexpr int SIGNATURE_LINK_ID_SIZE = 1; + static constexpr int SIGNATURE_TIMESTAMP_SIZE = 6; + static constexpr int SIGNATURE_SIGNATURE_SIZE = 6; + static constexpr int SIGNATURE_SIZE = SIGNATURE_LINK_ID_SIZE + SIGNATURE_TIMESTAMP_SIZE + SIGNATURE_SIGNATURE_SIZE; static constexpr int MAX_MESSAGE_SIZE = MAX_PAYLOAD_SIZE + HEADER_SIZE + CHECKSUM_SIZE + SIGNATURE_SIZE; + static constexpr int KEY_SIZE = 32; [[nodiscard]] inline const std::string& name() const { return _name; diff --git a/include/mav/Network.h b/include/mav/Network.h index b9c42d3..112724d 100644 --- a/include/mav/Network.h +++ b/include/mav/Network.h @@ -43,6 +43,7 @@ #include #include #include +#include #include #include #include "Connection.h" @@ -83,6 +84,7 @@ namespace mav { [[nodiscard]] virtual bool isConnectionOriented() const { return false; }; + virtual ~NetworkInterface() = default; }; @@ -144,6 +146,10 @@ namespace mav { } }; + static uint64_t _get_timestamp_function_default() { + const auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast(now.time_since_epoch()).count(); + } class NetworkRuntime { private: @@ -157,7 +163,8 @@ namespace mav { std::mutex _heartbeat_message_mutex; StreamParser _parser; Identifier _own_id; - std::array _key; + bool _sign; + std::array _key; std::function _get_timestamp_function; std::mutex _connections_mutex; std::mutex _send_mutex; @@ -172,11 +179,11 @@ namespace mav { std::function&)> _on_connection_lost; void _sendMessage(Message &message, const ConnectionPartner &partner) { - const bool sign = bool(_get_timestamp_function); - int wire_length = static_cast(message.finalize(_seq++, _own_id, sign)); - if (sign) { - message.sign(_key, _get_timestamp_function()); - wire_length += MessageDefinition::SIGNATURE_SIZE; + int wire_length; + if (_sign) { + wire_length = static_cast(message.finalize(_seq++, _own_id, _key, _get_timestamp_function())); + } else { + wire_length = static_cast(message.finalize(_seq++, _own_id)); } std::unique_lock lock(_send_mutex); _interface.send(message.data(), wire_length, partner); @@ -262,7 +269,7 @@ namespace mav { _expireConnection(connection.second, std::make_exception_ptr(e)); } } - } catch (NetworkInterfaceInterrupt &e) { + } catch (NetworkInterfaceInterrupt&) { _should_terminate.store(true); } } @@ -296,7 +303,7 @@ namespace mav { _expireConnection(connection.second, std::make_exception_ptr(e)); } } - } catch (NetworkInterfaceInterrupt &e) { + } catch (NetworkInterfaceInterrupt&) { _should_terminate.store(true); } @@ -333,6 +340,7 @@ namespace mav { std::function&)> on_connection_lost = {}) : _interface(interface), _message_set(message_set), _parser(_message_set, _interface), _own_id(own_id), + _sign(false), _get_timestamp_function(_get_timestamp_function_default), _on_connection(std::move(on_connection)), _on_connection_lost(std::move(on_connection_lost)) { _receive_thread = std::thread{ @@ -411,16 +419,19 @@ namespace mav { _heartbeat_message = std::nullopt; } - void setGetTimestampFunction(std::function function) { - _get_timestamp_function = function; + void injectMessage(Message &message) { + _sendMessage(message, {}); } - void setKey(std::array key) { + void enableMessageSigning(std::array key, + std::function timestampFunction = _get_timestamp_function_default) { + _sign = true; _key = key; + _get_timestamp_function = timestampFunction; } - void sendMessage(Message &message) { - _sendMessage(message, {}); + void disableMessageSigning() { + _sign = false; } void stop() { diff --git a/include/mav/utils.h b/include/mav/utils.h index 65736c1..fc7cf8e 100644 --- a/include/mav/utils.h +++ b/include/mav/utils.h @@ -32,6 +32,7 @@ * ****************************************************************************/ +#include #include #include #include @@ -77,14 +78,14 @@ namespace mav { }; template - inline void serialize(const T &v, uint8_t* destination) { + inline void serialize(const T &v, uint8_t* destination) noexcept { auto src_ptr = static_cast(static_cast(&v)); std::copy(src_ptr, src_ptr + sizeof(T), destination); } template - inline T deserialize(const uint8_t* source, int deserialize_size) { + inline T deserialize(const uint8_t* source, int deserialize_size) noexcept { // in case we do not have any bytes to read, we return 0 if (deserialize_size <= 0) { return T{0}; @@ -99,7 +100,7 @@ namespace mav { } - inline uint64_t millis() { + inline uint64_t millis() noexcept { return std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()).count(); } @@ -115,21 +116,25 @@ namespace mav { static constexpr formatEndType end{}; static constexpr formatEndType toString{}; +#if __GNUC__ > 9 + StringFormat() noexcept = default; +#endif + template - StringFormat& operator<< (const T &value) { + StringFormat& operator<< (const T &value) noexcept { _stream << value; return *this; } - std::string operator<< (const formatEndType&) { + std::string operator<< (const formatEndType&) noexcept { return _stream.str(); } - std::string str() const { + std::string str() const noexcept { return _stream.str(); } - explicit operator std::string() const { + explicit operator std::string() const noexcept { return _stream.str(); } }; @@ -203,25 +208,25 @@ namespace mav { } - template - A _packUnpack(B b) { - union U { - A a; - B b; - }; - U u; - u.b = b; - return u.a; + template + To _packUnpack(From o) { + static_assert(sizeof(To) == sizeof(From), "Cannot pack/unpack different sizes"); + static_assert(std::is_trivially_copyable::value, "Cannot pack/unpack non-trivially copyable types"); + static_assert(std::is_trivially_copyable::value, "Cannot pack/unpack non-trivially copyable types"); + To result; + std::memcpy(&result, &o, sizeof(To)); + return result; } + - template - T floatUnpack(float f) { - return _packUnpack(f); + template + To floatUnpack(float f) { + return _packUnpack(f); } - template - float floatPack(T o) { - return _packUnpack(o); + template + float floatPack(From o) { + return _packUnpack(o); } diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..4d9c71f --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,13 @@ +sonar.projectKey = Auterion_libmav +sonar.organization = auterion + +sonar.projectName = libmav + +sonar.sources = include/,tests/ + +sonar.exclusions = include/mav/rapidxml/*,tests/doctest.h,include/mav/picosha2/* + +sonar.coverage.exclusions = tests/**/*,include/mav/rapidxml/*,include/mav/picosha2/* +sonar.cpd.exclusions = tests/**/* + +sonar.sourceEncoding = UTF-8 diff --git a/tests/Message.cpp b/tests/Message.cpp index 051473c..166a9f0 100644 --- a/tests/Message.cpp +++ b/tests/Message.cpp @@ -34,12 +34,25 @@ TEST_CASE("Message set creation") { description description + + description + description + description + description + description + description + description + description + description + description + description + )""""); REQUIRE(message_set.contains("BIG_MESSAGE")); - REQUIRE_EQ(message_set.size(), 2); + REQUIRE_EQ(message_set.size(), 3); auto message = message_set.create("BIG_MESSAGE"); CHECK_EQ(message.id(), message_set.idForMessage("BIG_MESSAGE")); @@ -78,7 +91,7 @@ TEST_CASE("Message set creation") { SUBCASE("Have fields truncated by zero-elision") { message.set("int64_field", 34); // since largest field, will be at the end of the message CHECK_EQ(message.get("int64_field"), 34); - auto ret = message.finalize(1, {2,3}); + auto ret = message.finalize(1, {2, 3}); CHECK_EQ(message.get("int64_field"), 34); } @@ -130,19 +143,19 @@ TEST_CASE("Message set creation") { SUBCASE("Can set values with initializer list API") { message.set({ - {"uint8_field", 0x12}, - {"int8_field", 0x12}, - {"uint16_field", 0x1234}, - {"int16_field", 0x1234}, - {"uint32_field", 0x12345678}, - {"int32_field", 0x12345678}, - {"uint64_field", 0x1234567890ABCDEF}, - {"int64_field", 0x1234567890ABCDEF}, - {"double_field", 0.123456789}, - {"float_field", 0.123456789}, - {"char_arr_field", "Hello World!"}, - {"float_arr_field", std::vector{1.0, 2.0, 3.0}}, - {"int32_arr_field", std::vector{1, 2, 3}}}); + {"uint8_field", 0x12}, + {"int8_field", 0x12}, + {"uint16_field", 0x1234}, + {"int16_field", 0x1234}, + {"uint32_field", 0x12345678}, + {"int32_field", 0x12345678}, + {"uint64_field", 0x1234567890ABCDEF}, + {"int64_field", 0x1234567890ABCDEF}, + {"double_field", 0.123456789}, + {"float_field", 0.123456789}, + {"char_arr_field", "Hello World!"}, + {"float_arr_field", std::vector{1.0, 2.0, 3.0}}, + {"int32_arr_field", std::vector{1, 2, 3}}}); CHECK_EQ(static_cast(message["uint8_field"]), 0x12); CHECK_EQ(static_cast(message["int8_field"]), 0x12); @@ -201,6 +214,9 @@ TEST_CASE("Message set creation") { message["float_field"].floatPack(0x23456789); CHECK_EQ(message["float_field"].floatUnpack(), 0x23456789); + + CHECK_THROWS_AS(message["float_field"].floatUnpack(), std::runtime_error); + CHECK_THROWS_AS(message["float_field"].floatUnpack>(), std::runtime_error); } SUBCASE("Set and get a single field in array outside of range") { @@ -258,6 +274,129 @@ TEST_CASE("Message set creation") { CHECK_EQ(this_test_message.get("field4"), 0); } + SUBCASE("Test all array conversions") { + auto this_test_message = message_set.create("ARRAY_ONLY_MESSAGE"); + this_test_message["field1"] = "Hello"; + this_test_message["field2"] = std::vector{1, 2, 3}; + this_test_message["field3"] = std::vector{4, 5, 6}; + this_test_message["field4"] = std::vector{7, 8, 9}; + this_test_message["field5"] = std::vector{10, 11, 12}; + this_test_message["field6"] = std::vector{13, 14, 15}; + this_test_message["field7"] = std::vector{16, 17, 18}; + this_test_message["field8"] = std::vector{19, 20, 21}; + this_test_message["field9"] = std::vector{22, 23, 24}; + this_test_message["field10"] = std::vector{25, 26, 27}; + this_test_message["field11"] = std::vector{28, 29, 30}; + + CHECK_EQ(this_test_message["field1"].as(), "Hello"); + CHECK_EQ(this_test_message["field2"].as>(), std::vector{1, 2, 3}); + CHECK_EQ(this_test_message["field3"].as>(), std::vector{4, 5, 6}); + CHECK_EQ(this_test_message["field4"].as>(), std::vector{7, 8, 9}); + CHECK_EQ(this_test_message["field5"].as>(), std::vector{10, 11, 12}); + CHECK_EQ(this_test_message["field6"].as>(), std::vector{13, 14, 15}); + CHECK_EQ(this_test_message["field7"].as>(), std::vector{16, 17, 18}); + CHECK_EQ(this_test_message["field8"].as>(), std::vector{19, 20, 21}); + CHECK_EQ(this_test_message["field9"].as>(), std::vector{22, 23, 24}); + CHECK_EQ(this_test_message["field10"].as>(), std::vector{25, 26, 27}); + CHECK_EQ(this_test_message["field11"].as>(), std::vector{28, 29, 30}); + } + SUBCASE("Can get as native type variant") { + message.setFromNativeTypeVariant("uint8_field", {static_cast(1)}); + message.setFromNativeTypeVariant("int8_field", {static_cast(2)}); + message.setFromNativeTypeVariant("uint16_field", {static_cast(3)}); + message.setFromNativeTypeVariant("int16_field", {static_cast(4)}); + message.setFromNativeTypeVariant("uint32_field", {static_cast(5)}); + message.setFromNativeTypeVariant("int32_field", {static_cast(6)}); + message.setFromNativeTypeVariant("uint64_field", {static_cast(7)}); + message.setFromNativeTypeVariant("int64_field", {static_cast(8)}); + message.setFromNativeTypeVariant("double_field", {9.0}); + message.setFromNativeTypeVariant("float_field", {10.0f}); + message.setFromNativeTypeVariant("char_arr_field", {"Hello World!"}); + message.setFromNativeTypeVariant("float_arr_field", {std::vector{1.0, 2.0, 3.0}}); + message.setFromNativeTypeVariant("int32_arr_field", {std::vector{4, 5, 6}}); + CHECK_EQ(message.get("uint8_field"), 1); + CHECK_EQ(message.get("int8_field"), 2); + CHECK_EQ(message.get("uint16_field"), 3); + CHECK_EQ(message.get("int16_field"), 4); + CHECK_EQ(message.get("uint32_field"), 5); + CHECK_EQ(message.get("int32_field"), 6); + CHECK_EQ(message.get("uint64_field"), 7); + CHECK_EQ(message.get("int64_field"), 8); + CHECK_EQ(message.get("double_field"), doctest::Approx(9.0)); + CHECK_EQ(message.get("float_field"), doctest::Approx(10.0)); + CHECK_EQ(message.get("char_arr_field"), "Hello World!"); + CHECK_EQ(message.get>("float_arr_field"), std::vector{1.0, 2.0, 3.0}); + CHECK_EQ(message.get>("int32_arr_field"), std::vector{4, 5, 6}); + + + + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("uint8_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("int8_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("uint16_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("int16_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("uint32_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("int32_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("uint64_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("int64_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("double_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("float_field"))); + CHECK(std::holds_alternative(message.getAsNativeTypeInVariant("char_arr_field"))); + CHECK(std::holds_alternative>(message.getAsNativeTypeInVariant("float_arr_field"))); + CHECK(std::holds_alternative>(message.getAsNativeTypeInVariant("int32_arr_field"))); + + + auto this_test_message = message_set.create("ARRAY_ONLY_MESSAGE"); + CHECK(std::holds_alternative(this_test_message.getAsNativeTypeInVariant("field1"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field2"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field3"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field4"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field5"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field6"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field7"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field8"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field9"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field10"))); + CHECK(std::holds_alternative>(this_test_message.getAsNativeTypeInVariant("field11"))); + } + + SUBCASE("Test toString") { + message["uint8_field"] = 1; + message["int8_field"] = -2; + message["uint16_field"] = 3; + message["int16_field"] = -4; + message["uint32_field"] = 5; + message["int32_field"] = -6; + message["uint64_field"] = 7; + message["int64_field"] = 8; + message["double_field"] = 9.0; + message["float_field"] = 10.0; + message["char_arr_field"] = "Hello World!"; + message["float_arr_field"] = std::vector{1.0, 2.0, 3.0}; + message["int32_arr_field"] = std::vector{4, 5, 6}; + CHECK_EQ(message.toString(), + "Message ID 9915 (BIG_MESSAGE) \n char_arr_field: \"Hello World!\"\n double_field: 9\n float_arr_field: 1, 2, 3\n float_field: 10\n int16_field: -4\n int32_arr_field: 4, 5, 6\n int32_field: -6\n int64_field: 8\n int8_field: -2\n uint16_field: 3\n uint32_field: 5\n uint64_field: 7\n uint8_field: 1\n"); + } + + SUBCASE("Sign a packet") { + + std::array key; + for (int i = 0 ; i < 32; i++) key[i] = i; + + uint64_t timestamp = 770479200; + + // Attempt to access signature before signed (const & non-const versions) + const auto const_message = message_set.create("UINT8_ONLY_MESSAGE"); + CHECK_THROWS_AS(message.signature(), std::runtime_error); + CHECK_THROWS_AS(const_message.signature(), std::runtime_error); + + uint32_t wire_size = message.finalize(5, {6, 7}, key, timestamp); + + CHECK_EQ(wire_size, 26); + CHECK(message.header().isSigned()); + CHECK_NE(message.signature().signature(), 0); + CHECK_EQ(message.signature().timestamp(), timestamp); + CHECK(message.validate(key)); + } } diff --git a/tests/Network.cpp b/tests/Network.cpp index 744b309..18b4e03 100644 --- a/tests/Network.cpp +++ b/tests/Network.cpp @@ -90,6 +90,9 @@ class DummyInterface : public NetworkInterface { } }; +uint64_t getTimestamp() { + return 770479200; +} TEST_CASE("Create network runtime") { @@ -122,6 +125,9 @@ TEST_CASE("Create network runtime") { DummyInterface interface; NetworkRuntime network({253, 1}, message_set, interface); + std::array key; + for (int i = 0 ; i < 32; i++) key[i] = i; + // send a heartbeat message, to establish a connection interface.addToReceiveQueue("\xfd\x09\x00\x00\x00\xfd\x01\x00\x00\x00\x04\x00\x00\x00\x01\x02\x03\x05\x06\x77\x53"s, interface_partner); auto connection = network.awaitConnection(); @@ -159,6 +165,19 @@ TEST_CASE("Create network runtime") { CHECK_EQ(message.name(), "TEST_MESSAGE"); } + SUBCASE("Receiver re-synchronizes when garbage data between messages") { + interface.reset(); + auto expectation_1 = connection->expect("TEST_MESSAGE"); + auto expectation_2 = connection->expect("TEST_MESSAGE"); + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner); + interface.addToReceiveQueue("this is garbage data"s, interface_partner); + interface.addToReceiveQueue("\xfd\x10\x00\x00\x01\x61\x61\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x53\xd9"s, interface_partner); + auto message_1 = connection->receive(expectation_1); + CHECK_EQ(message_1.name(), "TEST_MESSAGE"); + auto message_2 = connection->receive(expectation_2); + CHECK_EQ(message_2.name(), "TEST_MESSAGE"); + } + SUBCASE("Can not receive message from wrong partner") { interface.reset(); auto expectation = connection->expect("TEST_MESSAGE"); @@ -200,4 +219,50 @@ TEST_CASE("Create network runtime") { CHECK_EQ(message.name(), "HEARTBEAT"); CHECK(connection->alive()); } + + SUBCASE("Enable message signing") { + auto message = message_set.create("TEST_MESSAGE")({ + {"value", 42}, + {"text", "Hello World!"} + }); + interface.reset(); + network.enableMessageSigning(key); + connection->send(message); + CHECK(message.header().isSigned()); + // don't check anything after link_id in signature as the timestamp is dependent on current time + bool found = (interface.sendSpongeContains( + "\xfd\x10\x01\x00\x00\xfd\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\xfd\x33\x00"s, + interface_partner)); + CHECK(found); + } + + SUBCASE("Enable message signing with custom timestamp function") { + auto message = message_set.create("TEST_MESSAGE")({ + {"value", 42}, + {"text", "Hello World!"} + }); + interface.reset(); + network.enableMessageSigning(key, getTimestamp); + connection->send(message); + CHECK(message.header().isSigned()); + bool found = (interface.sendSpongeContains( + "\xfd\x10\x01\x00\x00\xfd\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\xfd\x33\x00\x60\x94\xec\x2d\x00\x00\x7b\xab\xfa\x1a\xed\xf9"s, + interface_partner)); + CHECK(found); + } + + SUBCASE("Disable message signing") { + auto message = message_set.create("TEST_MESSAGE")({ + {"value", 42}, + {"text", "Hello World!"} + }); + interface.reset(); + network.disableMessageSigning(); + connection->send(message); + CHECK(!message.header().isSigned()); + bool found = (interface.sendSpongeContains( + "\xfd\x10\x00\x00\x00\xfd\x01\xbc\x26\x00\x2a\x00\x00\x00\x48\x65\x6c\x6c\x6f\x20\x57\x6f\x72\x6c\x64\x21\x86\x37"s, + interface_partner)); + CHECK(found); + } }