From 57c0e2f7ef7e92c1d6c6cd59072e794e7a32ce2a Mon Sep 17 00:00:00 2001 From: Aswinmcw Date: Wed, 22 May 2024 07:56:16 +0000 Subject: [PATCH] #8683: Add Unary left shift support for WH_B0 --- docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 2 + .../python_api_testing/sweep_tests/op_map.py | 4 ++ .../pytests/tt_dnn/test_left_shift.py | 71 +++++++++++++++++++ .../sweep_tests/pytorch_ops.py | 6 ++ .../sweep_tests/tt_lib_ops.py | 18 +++++ .../eltwise_unary/eltwise_unary_op.cpp | 6 ++ .../eltwise_unary/eltwise_unary_op.hpp | 7 +- .../tt_dnn/op_library/prod/prod_op_all.cpp | 1 - .../csrc/tt_lib_bindings_tensor_xary_ops.cpp | 17 +++++ .../metal/llk_api/llk_math_unary_sfpu_api.h | 1 + .../llk_sfpu/ckernel_sfpu_left_shift.h | 32 +++++++++ .../llk_math_eltwise_unary_sfpu_left_shift.h | 29 ++++++++ .../metal/llk_api/llk_sfpu_types.h | 1 + .../eltwise_unary/left_shift.h | 46 ++++++++++++ .../eltwise_unary/sfpu_split_includes.h | 4 ++ 15 files changed, 242 insertions(+), 3 deletions(-) create mode 100644 tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_left_shift.py create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_left_shift.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_left_shift.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/left_shift.h diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index bb3b528900d..512458eb61f 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -414,6 +414,8 @@ Tensor elementwise operations .. autofunction:: tt_lib.tensor.heaviside .. autofunction:: tt_lib.tensor.right_shift + +.. autofunction:: tt_lib.tensor.left_shift .. 
autofunction:: tt_lib.tensor.logaddexp diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py index ae4b3524aa6..e96717c6cc9 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py @@ -492,6 +492,10 @@ "tt_op": tt_lib_ops.eltwise_right_shift, "pytorch_op": pytorch_ops.right_shift, }, + "eltwise-left_shift": { + "tt_op": tt_lib_ops.eltwise_left_shift, + "pytorch_op": pytorch_ops.left_shift, + }, "eltwise-unary_ne": { "tt_op": tt_lib_ops.eltwise_unary_ne, "pytorch_op": pytorch_ops.unary_ne, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_left_shift.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_left_shift.py new file mode 100644 index 00000000000..23dd6d236af --- /dev/null +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_left_shift.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from functools import partial +import tt_lib as ttl + + +from tests.tt_eager.python_api_testing.sweep_tests import ( + comparison_funcs, + generation_funcs, +) +from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( + run_single_pytorch_test, +) +from models.utility_functions import skip_for_grayskull + +mem_configs = [ + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), +] + + +@pytest.mark.parametrize( + "scalar", + (3, 2, 1, 0), +) +@pytest.mark.parametrize( + "input_shapes", + [ + [[1, 1, 32, 32]], + [[4, 3, 32, 32]], + [[2, 2, 32, 32]], + ], +) +@pytest.mark.parametrize( + "dst_mem_config", + mem_configs, +) +@skip_for_grayskull("#TODO: GS implementation needs to be done") +class TestLeftShift: + def test_run_left_shift_op( + self, + scalar, + input_shapes, + dst_mem_config, + device, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.int) + ] + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update( + { + "value": scalar, + "dtype": [(ttl.tensor.DataType.INT32)], + } + ) + test_args.update({"output_mem_config": dst_mem_config}) + comparison_func = comparison_funcs.comp_equal + + run_single_pytorch_test( + "eltwise-left_shift", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 41906b8fd18..336884e61cb 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -533,6 +533,12 @@ def right_shift(x, *args, **kwargs): return result +def left_shift(x, *args, 
**kwargs): + value = kwargs.pop("value") + result = torch.bitwise_left_shift(x, value) + return result + + def unary_ne(x, *args, **kwargs): value = kwargs.pop("scalar") result = torch.ne(x, value) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index cd506e79e5f..14b180787d1 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1195,6 +1195,24 @@ def eltwise_right_shift( return tt2torch_tensor(t1) +@setup_host_and_device +def eltwise_left_shift( + x, + *args, + value, + device, + dtype, + layout, + input_mem_config, + output_mem_config, + **kwargs, +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = ttl.tensor.left_shift(t0, value, output_mem_config=output_mem_config) + + return tt2torch_tensor(t1) + + @setup_host_and_device def eltwise_heaviside( x, diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index 5dbf2d45c04..f327e839966 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -71,6 +71,7 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_parameterized( "right_shift_tile_init();", fmt::format("right_shift_tile({}, {}u);", idst, std::to_string((uint)param0))}; break; + case UnaryOpType::LEFT_SHIFT: + op_init_and_name = { + "left_shift_tile_init();", + fmt::format("left_shift_tile({}, {}u);", idst, std::to_string((uint)param0))}; + break; case UnaryOpType::EXP: op_init_and_name = { fmt::format("exp_tile_init<{}u>();", std::to_string((uint32_t)param0)), diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index 910d0aa5681..ee81024cf74 100644 --- 
a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -80,7 +80,8 @@ enum class UnaryOpType { TILED_PROD, TYPECAST, RIGHT_SHIFT, - FLOOR + FLOOR, + LEFT_SHIFT }; template @@ -108,7 +109,8 @@ bool is_parametrized_type(T val) { case UnaryOpType::UNARY_GT: case UnaryOpType::UNARY_LT: case UnaryOpType::TYPECAST: - case UnaryOpType::RIGHT_SHIFT: return true; + case UnaryOpType::RIGHT_SHIFT: + case UnaryOpType::LEFT_SHIFT: return true; default: return false; } return false; @@ -370,6 +372,7 @@ constexpr auto prelu = leaky_relu; constexpr auto elu = make_eltwise_unary_with_param{}; constexpr auto heaviside = make_eltwise_unary_with_param{}; constexpr auto right_shift = make_eltwise_unary_with_param{}; +constexpr auto left_shift = make_eltwise_unary_with_param{}; constexpr auto unary_ne = make_eltwise_unary_with_param{}; constexpr auto rsub = make_eltwise_unary_with_param{}; constexpr auto silu = make_eltwise_unary{}; diff --git a/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp b/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp index 385321f1431..03029cd868b 100644 --- a/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp +++ b/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp @@ -52,7 +52,6 @@ Tensor prod_all(const Tensor& input, const MemoryConfig& output_mem_config ) { } //else --> GS Arch return tt::numpy::prod_result_computation_GS(result, result.get_dtype(), result.get_layout(), result.device(), output_mem_config); - return operation::run(Prod_op{.output_mem_config = output_mem_config}, {input}).at(0); } } diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index a53b37791fd..bea3355b821 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -203,6 +203,23 @@ namespace tt::tt_metal::detail { "output_mem_config", "Layout of tensor in TT 
Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); + m_tensor.def("left_shift",left_shift, + py::arg("input").noconvert(),py::arg("shift_amt"),py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc( + Computes left shift of input tensor ``input`` by ``shift_amt`` bits. ``shift_amt`` range must be [0, 31]. Support provided only for Wormhole_B0. + + Input tensor must have INT32 data type. + + Output tensor will have INT32 data type. + + .. csv-table:: + :header: "Argument", "Description", "Data type", "Valid range", "Required" + + "input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "shift_amt", "Number of shift bits", "int", "[0, 31]", "Yes" + "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + + )doc"); + detail::bind_unary_op_with_param( m_tensor, "unary_ne", unary_ne, py::arg("value"), diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h index cc5bbecd0fc..0ac5d901d2f 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h @@ -27,3 +27,4 @@ #include "llk_math_eltwise_unary_sfpu_trigonometry.h" #include "llk_math_eltwise_unary_sfpu_unary_comp.h" #include "llk_math_eltwise_unary_sfpu_right_shift.h" +#include "llk_math_eltwise_unary_sfpu_left_shift.h" diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_left_shift.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_left_shift.h new file mode 100644 index 00000000000..b2324d64633 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_left_shift.h @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "noc_nonblocking_api.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_left_shift(const uint shift_amt) { +#pragma GCC unroll 0 + for (int d = 0; d < ITERATIONS; d++) { + vInt val = dst_reg[0]; + vInt v = val; + + val = val << shift_amt; + val = setsgn(val, v); + dst_reg[0] = val; + + dst_reg++; + } +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_left_shift.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_left_shift.h new file mode 100644 index 00000000000..c352923853d --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_left_shift.h @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel_sfpu_left_shift.h" +#include "llk_math_eltwise_unary_sfpu_params.h" +#include "llk_math_eltwise_unary_sfpu_init.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_left_shift_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_left_shift(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_left_shift, + dst_index, + vector_mode, + param0); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h index 6aa0a179972..372621b8737 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h @@ -77,5 +77,6 @@ enum SfpuType { tiled_prod, right_shift, floor, + 
left_shift, unused, }; diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/left_shift.h b/tt_metal/include/compute_kernel_api/eltwise_unary/left_shift.h new file mode 100644 index 00000000000..091382bd60a --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/left_shift.h @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_left_shift.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + + + +namespace ckernel { + +/** + * Performs element-wise left_shift computation on input x by y bits, where x is each element of a tile + * in DST register at index idst. The input must be of int data type only. The shift amount y is provided as param0. The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking and is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid + * Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be + * less than the size of the DST register buffer | True | | param0 | The number of bits to shift each element + * left by | uint32_t | Must be in the range [0, 31] | True | + */ +ALWI void left_shift_tile(uint32_t idst, uint32_t param0) { + MATH((llk_math_eltwise_unary_sfpu_left_shift(idst, param0))); +} + +/** + * Please refer to documentation for any_init.
+ */ +ALWI void left_shift_tile_init() { MATH((llk_math_eltwise_unary_sfpu_left_shift_init())); } + + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h index a0563fa817b..708c4905626 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h @@ -76,6 +76,10 @@ #include "compute_kernel_api/eltwise_unary/floor.h" #endif +#if SFPU_OP_LEFT_SHIFT_INCLUDE +#include "compute_kernel_api/eltwise_unary/left_shift.h" +#endif + #if SFPU_OP_BINOP_WITH_SCALAR_INCLUDE #include "compute_kernel_api/eltwise_unary/binop_with_scalar.h" #endif