Skip to content

Commit

Permalink
#8683: Add Unary left shift support for WH_B0
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Jun 6, 2024
1 parent e5d966e commit 57c0e2f
Show file tree
Hide file tree
Showing 15 changed files with 242 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ Tensor elementwise operations
.. autofunction:: tt_lib.tensor.heaviside

.. autofunction:: tt_lib.tensor.right_shift

.. autofunction:: tt_lib.tensor.left_shift

.. autofunction:: tt_lib.tensor.logaddexp

Expand Down
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@
"tt_op": tt_lib_ops.eltwise_right_shift,
"pytorch_op": pytorch_ops.right_shift,
},
"eltwise-left_shift": {
"tt_op": tt_lib_ops.eltwise_left_shift,
"pytorch_op": pytorch_ops.left_shift,
},
"eltwise-unary_ne": {
"tt_op": tt_lib_ops.eltwise_unary_ne,
"pytorch_op": pytorch_ops.unary_ne,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from functools import partial
import tt_lib as ttl


from tests.tt_eager.python_api_testing.sweep_tests import (
comparison_funcs,
generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)
from models.utility_functions import skip_for_grayskull

mem_configs = [
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
"scalar",
(3, 2, 1, 0),
)
@pytest.mark.parametrize(
"input_shapes",
[
[[1, 1, 32, 32]],
[[4, 3, 32, 32]],
[[2, 2, 32, 32]],
],
)
@pytest.mark.parametrize(
"dst_mem_config",
mem_configs,
)
@skip_for_grayskull("#TODO: GS implementation needs to be done")
class TestLeftShift:
def test_run_left_shift_op(
self,
scalar,
input_shapes,
dst_mem_config,
device,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), torch.int)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update(
{
"value": scalar,
"dtype": [(ttl.tensor.DataType.INT32)],
}
)
test_args.update({"output_mem_config": dst_mem_config})
comparison_func = comparison_funcs.comp_equal

run_single_pytorch_test(
"eltwise-left_shift",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
6 changes: 6 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,12 @@ def right_shift(x, *args, **kwargs):
return result


def left_shift(x, *args, **kwargs):
value = kwargs.pop("value")
result = torch.bitwise_left_shift(x, value)
return result


def unary_ne(x, *args, **kwargs):
value = kwargs.pop("scalar")
result = torch.ne(x, value)
Expand Down
18 changes: 18 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,24 @@ def eltwise_right_shift(
return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_left_shift(
x,
*args,
value,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.left_shift(t0, value, output_mem_config=output_mem_config)

return tt2torch_tensor(t1)


@setup_host_and_device
def eltwise_heaviside(
x,
Expand Down
6 changes: 6 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
case UnaryOpType::TYPECAST: defines["SFPU_OP_TYPECAST_INCLUDE"] = "1"; break;
case UnaryOpType::RIGHT_SHIFT: defines["SFPU_OP_RIGHT_SHIFT_INCLUDE"] = "1"; break;
case UnaryOpType::FLOOR: defines["SFPU_OP_FLOOR_INCLUDE"] = "1"; break;
case UnaryOpType::LEFT_SHIFT: defines["SFPU_OP_LEFT_SHIFT_INCLUDE"] = "1"; break;
default: defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"] = "1"; break;
};
}
Expand Down Expand Up @@ -119,6 +120,11 @@ std::pair<string, string> get_op_init_and_func_parameterized(
"right_shift_tile_init();",
fmt::format("right_shift_tile({}, {}u);", idst, std::to_string((uint)param0))};
break;
case UnaryOpType::LEFT_SHIFT:
op_init_and_name = {
"left_shift_tile_init();",
fmt::format("left_shift_tile({}, {}u);", idst, std::to_string((uint)param0))};
break;
case UnaryOpType::EXP:
op_init_and_name = {
fmt::format("exp_tile_init<{}u>();", std::to_string((uint32_t)param0)),
Expand Down
7 changes: 5 additions & 2 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ enum class UnaryOpType {
TILED_PROD,
TYPECAST,
RIGHT_SHIFT,
FLOOR
FLOOR,
LEFT_SHIFT
};

template <typename T>
Expand Down Expand Up @@ -108,7 +109,8 @@ bool is_parametrized_type(T val) {
case UnaryOpType::UNARY_GT:
case UnaryOpType::UNARY_LT:
case UnaryOpType::TYPECAST:
case UnaryOpType::RIGHT_SHIFT: return true;
case UnaryOpType::RIGHT_SHIFT:
case UnaryOpType::LEFT_SHIFT: return true;
default: return false;
}
return false;
Expand Down Expand Up @@ -370,6 +372,7 @@ constexpr auto prelu = leaky_relu;
constexpr auto elu = make_eltwise_unary_with_param<UnaryOpType::ELU>{};
constexpr auto heaviside = make_eltwise_unary_with_param<UnaryOpType::HEAVISIDE>{};
constexpr auto right_shift = make_eltwise_unary_with_param<UnaryOpType::RIGHT_SHIFT>{};
constexpr auto left_shift = make_eltwise_unary_with_param<UnaryOpType::LEFT_SHIFT>{};
constexpr auto unary_ne = make_eltwise_unary_with_param<UnaryOpType::UNARY_NE>{};
constexpr auto rsub = make_eltwise_unary_with_param<UnaryOpType::RSUB>{};
constexpr auto silu = make_eltwise_unary<UnaryOpType::SILU>{};
Expand Down
1 change: 0 additions & 1 deletion tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ Tensor prod_all(const Tensor& input, const MemoryConfig& output_mem_config ) {
}
//else --> GS Arch
return tt::numpy::prod_result_computation_GS<bfloat16>(result, result.get_dtype(), result.get_layout(), result.device(), output_mem_config);
return operation::run(Prod_op{.output_mem_config = output_mem_config}, {input}).at(0);
}

}
Expand Down
17 changes: 17 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,23 @@ namespace tt::tt_metal::detail {
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");
m_tensor.def("left_shift",left_shift,
py::arg("input").noconvert(),py::arg("shift_amt"),py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc(
Computes left shift of input tensor ``input`` by ``shift_amt`` bits. ``shift_amt`` range must be [0, 31]. Support provided only for Wormhole_B0.
Input tensor must have INT32 data type.
Output tensor will have INT32 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input", "Input Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"shift_amt", "Number of shift bits", "int", "[0, 31]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

detail::bind_unary_op_with_param(
m_tensor, "unary_ne", unary_ne,
py::arg("value"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_right_shift.h"
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_left_shift(const uint shift_amt) {
#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++) {
vInt val = dst_reg[0];
vInt v = val;

val = val << shift_amt;
val = setsgn(val, v);
dst_reg[0] = val;

dst_reg++;
}
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_left_shift.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_left_shift_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::left_shift, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_left_shift(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_left_shift<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,6 @@ enum SfpuType {
tiled_prod,
right_shift,
floor,
left_shift,
unused,
};
46 changes: 46 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/left_shift.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs element-wise left_shift computation on input x by y bits , where x is each element of a tile
* in DST register at index tile_index. The input must be of int data type only. The value is provided as const param0 The DST register buffer must be in
* acquired state via *acquire_dst* call. This call is blocking and is only
* available on the compute engine.
*
* Return value: None
*
* | Argument | Description | Type | Valid
* Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be
* less than the size of the DST register buffer | True | | param0 | The value the output is if the input
* is greater than 0 | uint32_t | | True |
*/
ALWI void left_shift_tile(uint32_t idst, uint32_t param0) {
MATH((llk_math_eltwise_unary_sfpu_left_shift<APPROX>(idst, param0)));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void left_shift_tile_init() { MATH((llk_math_eltwise_unary_sfpu_left_shift_init<APPROX>())); }


} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@
#include "compute_kernel_api/eltwise_unary/floor.h"
#endif

#if SFPU_OP_LEFT_SHIFT_INCLUDE
#include "compute_kernel_api/eltwise_unary/left_shift.h"
#endif

#if SFPU_OP_BINOP_WITH_SCALAR_INCLUDE
#include "compute_kernel_api/eltwise_unary/binop_with_scalar.h"
#endif
Expand Down

0 comments on commit 57c0e2f

Please sign in to comment.