diff --git a/tests/tt_metal/tt_metal/stl/CMakeLists.txt b/tests/tt_metal/tt_metal/stl/CMakeLists.txt index 0f1100b0e6f..9b7c87fb0c7 100644 --- a/tests/tt_metal/tt_metal/stl/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/stl/CMakeLists.txt @@ -1,6 +1,7 @@ set(UNIT_TESTS_STL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_any_range.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_slotmap.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_strong_type.cpp ) add_executable(unit_tests_stl ${UNIT_TESTS_STL_SRC}) diff --git a/tests/tt_metal/tt_metal/stl/test_strong_type.cpp b/tests/tt_metal/tt_metal/stl/test_strong_type.cpp new file mode 100644 index 00000000000..6983cca7f84 --- /dev/null +++ b/tests/tt_metal/tt_metal/stl/test_strong_type.cpp @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include + +#include "tt_metal/tt_stl/strong_type.hpp" + +using MyIntId = tt::stl::StrongType; +using MyStringId = tt::stl::StrongType; + +namespace tt::stl { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsNull; +using ::testing::UnorderedElementsAre; + +TEST(StrongTypeTest, Basic) { + MyIntId my_int_id1(42); + MyIntId my_int_id2(43); + + EXPECT_EQ(*my_int_id1, 42); + EXPECT_LT(*my_int_id1, *my_int_id2); + + my_int_id1 = MyIntId(43); + EXPECT_EQ(my_int_id1, my_int_id2); +} + +TEST(StrongTypeTest, UseInContainers) { + std::unordered_set unordered; + std::set ordered; + + unordered.insert(MyIntId(42)); + unordered.insert(MyIntId(43)); + + ordered.insert(MyIntId(1)); + ordered.insert(MyIntId(2)); + ordered.insert(MyIntId(3)); + + EXPECT_THAT(unordered, UnorderedElementsAre(MyIntId(42), MyIntId(43))); + EXPECT_THAT(ordered, ElementsAre(MyIntId(1), MyIntId(2), MyIntId(3))); +} + +TEST(StrongTypeTest, StreamingOperator) { + std::stringstream ss; + ss << MyStringId("hello world"); + EXPECT_EQ(ss.str(), "hello world"); +} + +TEST(StrongTypeTest, MoveOnlyType) { + using MoveOnlyType = StrongType, struct MoveOnlyTag>; + + MoveOnlyType from(std::make_unique(42)); + EXPECT_EQ(**from, 42); + + MoveOnlyType to = std::move(from); + + // NOLINTNEXTLINE(bugprone-use-after-move) + EXPECT_THAT(*from, IsNull()); + EXPECT_EQ(**to, 42); +} + +} // namespace +} // namespace tt::stl diff --git a/tt_metal/tt_stl/strong_type.hpp b/tt_metal/tt_stl/strong_type.hpp new file mode 100644 index 00000000000..f69309f8189 --- /dev/null +++ b/tt_metal/tt_stl/strong_type.hpp @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace tt::stl { + +// `StrongType` provides a strongly-typed wrapper around a value to prevent accidental type conversions. +// +// This is useful when creating aliases that rely on a primitive type; for example instead of using `uint32_t` as +// `DeviceId` directly, wrap in `StrongType` to prevent accidental assignment +// from `uint32_t`. Here, the 'tag' is used to disambiguate the type, and to create distinct wrappers relying on +// `uint32_t`. +// +// +// Example usage: +// +// // Create strong types. +// // `struct`s for the tag can be supplied as shown, despite of being incomplete: +// +// using UserId = StrongType; +// using GroupId = StrongType; +// using Username = StrongType; +// +// +// // The different types cannot be assigned to each other: +// UserId user_id(42); +// GroupId group_id(45); +// user_id = group_id; // does not compile! +// +// Username name("john_doe"); +// name = "jane_doe"; // does not compile! +// name = Username("jane_doe"); // instantiate explicitly. +// +// // Access the underlying value: +// uint32_t raw_user_id = *user_id; +// assert(*user_id < *group_id); +// +// // Strong types work with standard containers and the streaming operator, as long as the underlying type is +// // hashable and comparable. +// +// std::unordered_set user_set; +// user_set.insert(UserId(1)); +// user_set.insert(UserId(2)); + +// std::map user_map; +// user_map.emplace(UserId(1), Username("John Doe")); +// +// std::cout << user_map.at(UserId(1)) << std::endl; // "John Doe" +// +template +class StrongType { +public: + explicit StrongType(T v) : value_(std::move(v)) {} + + StrongType(const StrongType&) = default; + StrongType(StrongType&&) = default; + StrongType& operator=(const StrongType&) = default; + StrongType& operator=(StrongType&&) = default; + + const T& operator*() const { return value_; } + + auto operator<=>(const StrongType&) const = default; + +private: + T value_; +}; + +} // namespace tt::stl + +template +std::ostream& operator<<(std::ostream& os, const tt::stl::StrongType& h) { + return os << *h; +} + +template +struct std::hash> { + std::size_t operator()(const tt::stl::StrongType& h) const noexcept { return std::hash{}(*h); } +};