#8658: Migrate composite unary ops to C++
eyonland committed Jun 6, 2024
1 parent a5a6ddb commit ce75c4f
Showing 9 changed files with 550 additions and 159 deletions.
14 changes: 1 addition & 13 deletions tests/ttnn/unit_tests/operations/test_math.py
@@ -165,6 +165,7 @@ def test_rad2deg(device, h, w):
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_cbrt(device, h, w):
torch_cbrt = ttnn.get_golden_function(ttnn.cbrt)
run_math_unary_test(device, h, w, ttnn.cbrt, torch_cbrt, pcc=0.999)


@@ -270,19 +271,6 @@ def run_math_unary_test_range(device, h, w, ttnn_function, torch_function, pcc=0
assert_with_pcc(torch_output_tensor, output_tensor, pcc)


def torch_cbrt(x, *args, **kwargs):
return torch.sgn(x) * torch.pow(torch.abs(x), 1.0 / 3)


def torch_multigammaln(x, *args, **kwargs):
result = torch.lgamma(x)
result += torch.lgamma(x - 0.5)
result += torch.lgamma(x - 1.0)
result += torch.lgamma(x - 1.5)
result += 3.434189657547
return result


@pytest.mark.parametrize("h", [5])
@pytest.mark.parametrize("w", [5])
def test_multigammaln(device, h, w):
4 changes: 4 additions & 0 deletions tt_eager/tt_dnn/op_library/auto_format.hpp
@@ -38,6 +38,10 @@ class AutoFormat {
auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
// auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1;
// auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1;
// auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2];
// auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1];
Shape padded_shape = {n, c, h, w};
return padded_shape;
}
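
For reference, the computation above just rounds H and W up to tile boundaries. A minimal standalone sketch of that rounding (assuming 32x32 tiles and a local round_up helper, not the TT-Metal API):

#include <array>
#include <cstdint>
#include <iostream>

namespace sketch {

constexpr uint32_t TILE_HEIGHT = 32;
constexpr uint32_t TILE_WIDTH = 32;

// Round value up to the next multiple of `multiple`.
constexpr uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

}  // namespace sketch

int main() {
    // Unpadded NCHW shape; only H and W are rounded up to tile boundaries here.
    std::array<uint32_t, 4> unpadded = {1, 3, 30, 62};
    std::array<uint32_t, 4> padded = {
        unpadded[0],
        unpadded[1],
        sketch::round_up(unpadded[2], sketch::TILE_HEIGHT),  // 30 -> 32
        sketch::round_up(unpadded[3], sketch::TILE_WIDTH)};  // 62 -> 64
    std::cout << padded[0] << " " << padded[1] << " " << padded[2] << " " << padded[3] << "\n";
    return 0;
}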
@@ -24,6 +24,8 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
const auto bshape = b.get_legacy_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3];
// uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1];
// uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1];
uint32_t NC = N*C;
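
The commented-out lines sketch a rank-aware variant that indexes dimensions from the back and defaults missing leading dimensions to 1. A standalone illustration of that indexing convention (the dim_or_one helper is hypothetical, not the TT-NN shape API):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: index a shape from the end (idx_from_back == 1 means the
// last dimension) and default to 1 when the shape has fewer dimensions.
uint32_t dim_or_one(const std::vector<uint32_t>& shape, size_t idx_from_back) {
    return shape.size() >= idx_from_back ? shape[shape.size() - idx_from_back] : 1;
}

int main() {
    std::vector<uint32_t> ashape = {2, 64, 128};  // rank-3 shape: C, H, W
    uint32_t N = dim_or_one(ashape, 4);  // 1  (no batch dimension present)
    uint32_t C = dim_or_one(ashape, 3);  // 2
    uint32_t H = dim_or_one(ashape, 2);  // 64
    uint32_t W = dim_or_one(ashape, 1);  // 128
    std::cout << N << " " << C << " " << H << " " << W << "\n";
    return 0;
}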


14 changes: 12 additions & 2 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -26,17 +26,21 @@ namespace tt_metal {

Tensor mk_zero_tensor_like(const Tensor& reference_tensor, const MemoryConfig& output_mem_config) {
// Tensor zero_like = bcast(reference_tensor, , BcastOpMath::MUL, BcastOpDim::HW);
//Tensor zero = mk_tiled_scalar(0.0f, reference_tensor.get_dtype(), reference_tensor.device());
Tensor zero = mk_tiled_scalar(0.0f, reference_tensor.get_dtype());
Tensor zero_like = bcast(reference_tensor, zero, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
//zero.deallocate();
return zero_like;
}

// TODO: enable zeroes(), ones() and eye() type functions on-device using this type of logic
template <typename T>
Tensor mk_filled_tensor_like(const Tensor& reference_tensor, T val, const MemoryConfig& output_mem_config) {
Tensor k = mk_tiled_scalar(val, reference_tensor.get_dtype());
Tensor k = mk_tiled_scalar(val, reference_tensor.get_dtype(), reference_tensor.device());
Tensor zero_like = mk_zero_tensor_like(reference_tensor, output_mem_config);
Tensor result = bcast(zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
//k.deallocate();
//zero_like.deallocate();
return result;
}

@@ -407,7 +411,13 @@ Tensor mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig
Tensor _mac_overload(const Tensor& a, float b, float c, const MemoryConfig& output_mem_config) {
Tensor t_b = mk_scalar(b);
Tensor t_c = mk_scalar(c);
return mac(a, t_b, t_c, output_mem_config);
// Tensor t_b = mk_tiled_scalar(b, a.device());
// Tensor t_c = mk_tiled_scalar(c, a.device());

Tensor return_tensor = mac(a, t_b, t_c, output_mem_config);
// t_b.deallocate();
// t_c.deallocate();
return return_tensor;
}
Tensor mac(const Tensor& input_a, float b, float c, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _mac_overload)(input_a, b, c, output_mem_config);
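
For orientation, mac here is multiply-accumulate; assuming the conventional result = a * b + c definition, the element-wise semantics of the scalar overload can be sketched outside the library as:

#include <iostream>
#include <vector>

// result[i] = a[i] * b + c, with b and c broadcast over every element of a.
std::vector<float> mac_scalar(const std::vector<float>& a, float b, float c) {
    std::vector<float> result(a.size());
    for (size_t i = 0; i < a.size(); ++i) {
        result[i] = a[i] * b + c;
    }
    return result;
}

int main() {
    std::vector<float> a = {1.0f, 2.0f, 3.0f};
    for (float v : mac_scalar(a, 2.0f, 0.5f)) {
        std::cout << v << " ";  // prints: 2.5 4.5 6.5
    }
    std::cout << "\n";
    return 0;
}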
37 changes: 33 additions & 4 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -22,28 +22,49 @@ using binary_tensor_op_t = Tensor(const Tensor& a, const Tensor& b);

// Note: inline doesn't allow pybind to work well, so we keep a few functions not inlined.

// template <typename T>
// Tensor create(T scalar, short rank, DataType data_type, Layout layout, std::optional<Device*> device = std::nullopt);
// auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(TILE_HEIGHT * TILE_WIDTH));
// host_buffer[0] = scalar;
// Tensor scalar_tensor_host = Tensor(
// OwnedStorage{host_buffer},
// ttnn::Shape(std::array<std::uint32_t, 2>{1, 1}, std::array<std::uint32_t, 2>{TILE_HEIGHT, TILE_WIDTH}),
// DataType::BFLOAT16,
// Layout::TILE);
// if (device.has_value()) {
// return scalar_tensor_host.to(device.value());
// }
// return scalar_tensor_host;
// }

template <typename T>
Tensor mk_scalar(T value) {
Tensor mk_scalar(T value, std::optional<Device*> device = std::nullopt) {
assert(std::is_scalar<T>::value && "T should be scalar");
std::array<unsigned int, 4> shape = {1, 1, 1, 1};
auto buffer = owned_buffer::create(std::vector{bfloat16(value)});
Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::ROW_MAJOR);
if (device.has_value()) {
scalar = AutoFormat::move_tensor_to_device(scalar, device.value());
}
return scalar;
}

template <typename T>
Tensor mk_tiled_scalar(T value) {
Tensor mk_tiled_scalar(T value, std::optional<Device*> device = std::nullopt) {
assert(std::is_scalar<T>::value && "T should be scalar");
std::array<unsigned int, 4> shape = {1, 1, TILE_HEIGHT, TILE_WIDTH};
std::vector<bfloat16> buffer_vec(TILE_HW, bfloat16(0));
buffer_vec[0] = bfloat16(value);
auto buffer = owned_buffer::create(std::move(buffer_vec));
Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE);
if (device.has_value()) {
scalar = AutoFormat::move_tensor_to_device(scalar, device.value());
}
return scalar;
}

template <typename T>
Tensor mk_tiled_scalar(T value, DataType dtype) {
Tensor mk_tiled_scalar(T value, DataType dtype, std::optional<Device*> device = std::nullopt) {
assert(std::is_scalar<T>::value && "T should be scalar");
std::array<unsigned int, 4> shape = {1, 1, TILE_HEIGHT, TILE_WIDTH};
if(dtype == DataType::BFLOAT8_B)
@@ -52,12 +73,20 @@ Tensor mk_tiled_scalar(T value, DataType dtype) {
buffer_vec[0] = float(value);
auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(buffer_vec, /*row_major_input=*/false, /*is_exp_a=*/false);
auto output_uint32_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), shape, DataType::BFLOAT8_B, Layout::TILE);
if (device.has_value()) {
return AutoFormat::move_tensor_to_device(Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), shape, DataType::BFLOAT8_B, Layout::TILE), device.value());
}
else{
return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), shape, DataType::BFLOAT8_B, Layout::TILE);
}
}
std::vector<bfloat16> buffer_vec(TILE_HW, bfloat16(0));
buffer_vec[0] = bfloat16(value);
auto buffer = owned_buffer::create(std::move(buffer_vec));
Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE);
if (device.has_value()) {
scalar = AutoFormat::move_tensor_to_device(scalar, device.value());
}
return scalar;
}
// Function: softshrink
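
All three helpers above follow the same new pattern: build the scalar tensor on host, then move it to the device only when one is supplied. A schematic, self-contained sketch of that optional-device pattern (Device, Tensor, and move_to_device are stand-ins for illustration, not the real TT-NN types):

#include <iostream>
#include <optional>

struct Device { int id; };

struct Tensor {
    float value;
    bool on_device = false;
};

Tensor move_to_device(Tensor t, Device* device) {
    t.on_device = true;
    std::cout << "moved scalar tensor to device " << device->id << "\n";
    return t;
}

// Host-side scalar by default; placed on device only when one is supplied.
Tensor make_scalar(float value, std::optional<Device*> device = std::nullopt) {
    Tensor scalar{value};
    if (device.has_value()) {
        scalar = move_to_device(scalar, device.value());
    }
    return scalar;
}

int main() {
    Device d{0};
    Tensor host_only = make_scalar(2.5f);      // stays on host
    Tensor on_device = make_scalar(2.5f, &d);  // moved to device 0
    std::cout << host_only.on_device << " " << on_device.on_device << "\n";
    return 0;
}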
161 changes: 160 additions & 1 deletion ttnn/cpp/pybind11/operations/unary.hpp
@@ -47,7 +47,143 @@ void bind_unary_operation(py::module& module, const unary_operation_t& operation
module,
operation,
doc,
ttnn::pybind_arguments_t{py::arg("input_tensor"), py::kw_only(), py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt});
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, memory_config); },
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_scale_and_shift(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, scale, shift, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`scale`
* :attr:`shift`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float scale,
float shift,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, scale, shift, memory_config);
},
py::arg("input_tensor"),
py::arg("scale")=1.0f/6.0f,
py::arg("shift")=0.5f,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_low_and_high(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, low, high, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`low`
* :attr:`high`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
float low,
float high,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor, low, high, memory_config);
},
py::arg("input_tensor"),
py::arg("low") = -1.0f,
py::arg("high") = 1.0f,
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_operation_with_diag(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, diag, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
* :attr:`diag`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
int32_t diag,
const std::optional<MemoryConfig>& memory_config) { return self(input_tensor, diag, memory_config); },
py::arg("input_tensor"),
py::arg("diag"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
@@ -229,6 +365,29 @@ void py_module(py::module& module) {

// Other unaries (composite operations)
detail::bind_softplus(module);

detail::bind_unary_operation(module, ttnn::acosh);
detail::bind_unary_operation(module, ttnn::asinh);
detail::bind_unary_operation(module, ttnn::atanh);
detail::bind_unary_operation(module, ttnn::cbrt);
detail::bind_unary_operation(module, ttnn::cosh);
detail::bind_unary_operation(module, ttnn::deg2rad);
detail::bind_unary_operation(module, ttnn::digamma);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardswish);
detail::bind_unary_operation_with_scale_and_shift(module, ttnn::hardsigmoid);
detail::bind_unary_operation_with_low_and_high(module, ttnn::hardtanh);
detail::bind_unary_operation(module, ttnn::lgamma);
detail::bind_unary_operation(module, ttnn::log1p);
detail::bind_unary_operation(module, ttnn::mish);
detail::bind_unary_operation(module, ttnn::multigammaln);
detail::bind_unary_operation(module, ttnn::rad2deg);
detail::bind_unary_operation(module, ttnn::sigmoid_accurate);
detail::bind_unary_operation(module, ttnn::sinh);
detail::bind_unary_operation(module, ttnn::softsign);
detail::bind_unary_operation(module, ttnn::swish);
detail::bind_unary_operation(module, ttnn::tanhshrink);
detail::bind_unary_operation_with_diag(module, ttnn::tril);
detail::bind_unary_operation_with_diag(module, ttnn::triu);
}

} // namespace unary
44 changes: 42 additions & 2 deletions ttnn/cpp/ttnn/decorators.hpp
@@ -365,13 +365,53 @@ constexpr auto register_operation(const char* name) {
return operation_t<__COUNTER__, concrete_operation_t>{name};
}

#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })

template <typename lambda_t>
constexpr auto register_operation(const char* name, const lambda_t& lambda) {
return lambda_operation_t<__COUNTER__, lambda_t>{name, lambda};
}

// This function is used to transform the arguments of a function before calling it
// where the lambda is applied to the type that matches T.
// Example: https://godbolt.org/z/3P9YedMdj
template <typename T, typename Func, typename Lambda, typename... Args>
constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) {
auto transformer = [lambda](auto&& arg) -> decltype(auto) {
if constexpr (std::is_same_v<T, std::decay_t<decltype(arg)>>) {
return lambda(std::forward<decltype(arg)>(arg));
} else {
return std::forward<decltype(arg)>(arg);
}
};

return func(transformer(std::forward<Args>(args))...);
}

template <typename T, typename Lambda>
auto transform_first_matching_arg(Lambda lambda) {
static_assert(!std::is_same<T, T>::value, "No matching type found");
}

template <typename T, typename Lambda, typename First, typename... Rest>
auto transform_first_matching_arg(Lambda lambda, First&& first, Rest&&... rest) {
if constexpr (std::is_same_v<T, std::decay_t<First>>) {
return lambda(std::forward<First>(first));
} else {
return transform_first_matching_arg<T>(lambda, std::forward<Rest>(rest)...);
}
}

#define TO_LAMBDA(function) ([](auto&&... args) { return function(std::forward<decltype(args)>(args)...); })

#define TO_LAMBDA_WITH_RESHAPE(function) \
([](auto&&... args) { \
const auto original_shape = ttnn::decorators::transform_first_matching_arg<Tensor>( \
[&](auto&& tensor) { return tensor.get_shape(); }, std::forward<decltype(args)>(args)...); \
return ttnn::reshape( \
ttnn::decorators::transform_args_lambda<Tensor>( \
function, [&](auto&& tensor) { return ttnn::unsqueeze_to_4D(tensor); }, args...), \
original_shape); \
})

} // namespace decorators

using ttnn::decorators::register_operation;
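
A standalone example of the argument-transforming helper added above, in the spirit of the linked godbolt snippet: only arguments whose decayed type matches T are rewritten by the given lambda before the call is forwarded (the std::string/int example types are purely illustrative):

#include <iostream>
#include <string>
#include <type_traits>
#include <utility>

// Mirrors the transform_args_lambda template added in this commit.
template <typename T, typename Func, typename Lambda, typename... Args>
constexpr auto transform_args_lambda(Func func, Lambda lambda, Args&&... args) -> decltype(auto) {
    auto transformer = [lambda](auto&& arg) -> decltype(auto) {
        if constexpr (std::is_same_v<T, std::decay_t<decltype(arg)>>) {
            return lambda(std::forward<decltype(arg)>(arg));
        } else {
            return std::forward<decltype(arg)>(arg);
        }
    };
    return func(transformer(std::forward<Args>(args))...);
}

int main() {
    auto describe = [](const std::string& name, int value) {
        std::cout << name << " = " << value << "\n";
    };
    // Only the int argument matches T = int, so only it is doubled before the call.
    transform_args_lambda<int>(describe, [](int v) { return v * 2; }, std::string{"x"}, 21);
    // prints: x = 42
    return 0;
}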