Skip to content

Commit

Permalink
#0: Revert overloaded ops (#11705)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw authored Aug 21, 2024
1 parent bb471ec commit 93fbac3
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 301 deletions.
4 changes: 4 additions & 0 deletions tests/ttnn/unit_tests/operations/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def test_level2_recip(bs, memcfg, dtype, device, function_level_defaults):
assert passing


@pytest.mark.skip(reason="This test is failing because ttnn.add doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
Expand Down Expand Up @@ -292,6 +293,7 @@ def test_level2_add(bs, memcfg, dtype, device, function_level_defaults):
assert passing


@pytest.mark.skip(reason="This test is failing because ttnn.sub doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
Expand Down Expand Up @@ -329,6 +331,7 @@ def test_level2_sub(bs, memcfg, dtype, device, function_level_defaults):
assert passing


@pytest.mark.skip(reason="This test is failing because ttnn.mul doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
Expand Down Expand Up @@ -366,6 +369,7 @@ def test_level2_mul(bs, memcfg, dtype, device, function_level_defaults):
assert passing


@pytest.mark.skip(reason="This test is failing because ttnn.div doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
Expand Down
151 changes: 5 additions & 146 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement/repeat/repeat.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/eltwise/complex/complex.hpp"
#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"

namespace ttnn::operations::binary {

namespace detail {
Expand Down Expand Up @@ -236,146 +235,6 @@ Tensor BinaryOperation<binary_op_type, in_place>::invoke(
}


// Asynchronous entry point for the overloaded binary elementwise op.
// Allocates placeholder output tensors on the workers of both inputs, then
// enqueues the actual computation via operation::launch_op; the enqueued
// lambda preprocesses the inputs (broadcast/repeat handling in
// detail::preprocess_inputs) and dispatches to the ttnn::prim::binary
// primitive on the given command queue.
template <BinaryOpType binary_op_type, bool in_place>
Tensor BinaryOperationOverload<binary_op_type, in_place>::invoke(
uint8_t queue_id,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<const DataType> &output_dtype,
const std::optional<MemoryConfig> &memory_config,
std::optional<Tensor> optional_output_tensor,
std::optional<unary::FusedActivations> activations,
std::optional<unary::UnaryWithParam> input_tensor_a_activation) {

// Output placeholder bound to the workers that own both inputs; launch_op
// fills it in when the enqueued lambda runs.
std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a_arg, input_tensor_b_arg}))};
operation::launch_op(
// All config is captured by value — the lambda may outlive this call.
[queue_id, output_dtype, memory_config, optional_output_tensor, activations, input_tensor_a_activation](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

// Normalize the two inputs (e.g. shape alignment) before the primitive.
auto [input_tensor_a, input_tensor_b] = detail::preprocess_inputs<binary_op_type>(input_tensors[0], input_tensors[1]);

return {ttnn::prim::binary(
queue_id,
input_tensor_a,
input_tensor_b,
binary_op_type,
in_place,
output_dtype,
memory_config,
optional_output_tensor,
activations,
input_tensor_a_activation)};
},
{input_tensor_a_arg, input_tensor_b_arg},
output_tensors);

return output_tensors[0];
}

// Convenience overload without an explicit queue id: forwards to the
// queue-aware tensor-tensor overload on DefaultQueueId.
template <BinaryOpType binary_op_type, bool in_place>
Tensor BinaryOperationOverload<binary_op_type, in_place>::invoke(
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<const DataType> &output_dtype,
const std::optional<MemoryConfig> &memory_config,
std::optional<Tensor> optional_output_tensor,
std::optional<unary::FusedActivations> activations,
std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
return invoke(
DefaultQueueId,
input_tensor_a_arg,
input_tensor_b_arg,
output_dtype,
memory_config,
optional_output_tensor,
activations,
input_tensor_a_activation);
}

// TODO: this case should use BinaryWithScalarProgramConfig and there should be a custom kernel to run this
// Currently, this is exactly how tt::tt_metal::add_unary works
// Tensor-scalar convenience overload: forwards to the queue-aware
// tensor-scalar overload on DefaultQueueId.
template <BinaryOpType binary_op_type, bool in_place>
Tensor BinaryOperationOverload<binary_op_type, in_place>::invoke(
const ttnn::Tensor &input_tensor_a,
const float scalar,
const std::optional<const DataType> &dtype,
const std::optional<ttnn::MemoryConfig> &memory_config,
const std::optional<Tensor> &optional_output_tensor,
std::optional<unary::FusedActivations> activations,
std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
return BinaryOperationOverload::invoke(
DefaultQueueId,
input_tensor_a,
scalar,
dtype,
memory_config,
optional_output_tensor,
activations,
input_tensor_a_activation);
}

// Queue-aware tensor-scalar overload: materializes the float scalar as a
// 1x1 (tile-padded) BFLOAT16 device tensor and dispatches to the
// tensor-tensor overload.
//
// Fix: the original accepted `queue_id` but then called the overload WITHOUT a
// queue id, so the requested queue was silently ignored and DefaultQueueId was
// used instead. The queue id is now forwarded.
template <BinaryOpType binary_op_type, bool in_place>
Tensor BinaryOperationOverload<binary_op_type, in_place>::invoke(
    uint8_t queue_id,
    const ttnn::Tensor &input_tensor_a,
    const float scalar,
    const std::optional<const DataType> &dtype,
    const std::optional<ttnn::MemoryConfig> &memory_config,
    const std::optional<Tensor> &optional_output_tensor,
    std::optional<unary::FusedActivations> activations,
    std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
    // Cast Float Scalar to a device tensor. Only element 0 of the tile-sized
    // host buffer is written; the logical shape is 1x1 padded to a full tile.
    // NOTE(review): remaining buffer elements are whatever owned_buffer::create
    // initializes them to — presumably zero; confirm against its implementation.
    auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(TILE_HEIGHT * TILE_WIDTH));
    host_buffer[0] = scalar;
    Tensor scalar_tensor_host = Tensor(
        OwnedStorage{host_buffer},
        ttnn::Shape(std::array<std::uint32_t, 2>{1, 1}, std::array<std::uint32_t, 2>{TILE_HEIGHT, TILE_WIDTH}),
        DataType::BFLOAT16,
        Layout::TILE);
    Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device());
    // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG
    return BinaryOperationOverload::invoke(
        queue_id,
        input_tensor_a,
        scalar_tensor_device,
        dtype,
        memory_config,
        optional_output_tensor,
        activations,
        input_tensor_a_activation);
}

// Complex-tensor overload. ComplexTensor is stored as two real tensors
// (index 0 = real part, index 1 = imaginary part); the op type is dispatched
// at compile time via `if constexpr`:
//   ADD/SUB  : componentwise add/subtract of real and imaginary parts.
//   MUL      : (a+bi)(c+di) = (ac - bd) + (ad + bc)i.
//   DIV_FAST : a / b computed as a * reciprocal(b).
// Any other op type is a runtime error (TT_THROW).
template <BinaryOpType binary_op_type, bool in_place>
ComplexTensor BinaryOperationOverload<binary_op_type, in_place>::invoke(
const ComplexTensor &input_a,
const ComplexTensor &input_b,
const ttnn::MemoryConfig &output_mem_config) {
if constexpr(binary_op_type == BinaryOpType::ADD) {
return ComplexTensor({ ttnn::add(input_a[0], input_b[0], std::nullopt, output_mem_config),
ttnn::add(input_a[1], input_b[1], std::nullopt, output_mem_config) });
}else if constexpr(binary_op_type == BinaryOpType::SUB) {
return ComplexTensor({ ttnn::subtract(input_a[0], input_b[0], std::nullopt, output_mem_config),
ttnn::subtract(input_a[1], input_b[1], std::nullopt, output_mem_config) });
}else if constexpr(binary_op_type == BinaryOpType::MUL) {
// Real part: ac - bd
Tensor re_part = ttnn::subtract(
ttnn::multiply(input_a[0],input_b[0],std::nullopt,output_mem_config),
ttnn::multiply(input_a[1],input_b[1],std::nullopt,output_mem_config),
std::nullopt, output_mem_config);

// Imaginary part: ad + bc
Tensor im_part = ttnn::add(
ttnn::multiply(input_a[0],input_b[1],std::nullopt,output_mem_config),
ttnn::multiply(input_a[1],input_b[0],std::nullopt,output_mem_config),
std::nullopt, output_mem_config);

return ComplexTensor({ re_part, im_part });
}else if constexpr(binary_op_type == BinaryOpType::DIV_FAST) {
// Division via complex multiply by the complex reciprocal of input_b.
return ttnn::multiply( input_a, ttnn::reciprocal( input_b , output_mem_config ), output_mem_config );
}else {
TT_THROW("Unsupported operation (expected MUL or DIV_FAST or ADD or SUB)");
}
}

template <BinaryOpType binary_op_type>
Tensor RelationalBinary<binary_op_type>::invoke(
uint8_t queue_id,
Expand Down Expand Up @@ -495,19 +354,19 @@ Tensor InplaceLogicalBinary<binary_op_type>::invoke(
return BinaryOperation<binary_op_type, false>::invoke(input_tensor_a_arg, input_tensor_b_arg, std::nullopt, std::nullopt, input_tensor_a_arg, std::nullopt, std::nullopt);
}

// Explicit template instantiations: ADD/SUB/MUL/DIV_FAST get both the
// overloaded (complex/scalar-capable) form and the plain form; the remaining
// op types get only BinaryOperation (in_place = true variants where supported).
template struct BinaryOperationOverload<BinaryOpType::ADD, false>;
template struct BinaryOperation<BinaryOpType::ADD, false>;
template struct BinaryOperation<BinaryOpType::ADD, true>;
template struct BinaryOperationOverload<BinaryOpType::SUB, false>;
template struct BinaryOperation<BinaryOpType::SUB, false>;
template struct BinaryOperation<BinaryOpType::SUB, true>;
template struct BinaryOperationOverload<BinaryOpType::MUL, false>;
template struct BinaryOperation<BinaryOpType::MUL, false>;
template struct BinaryOperation<BinaryOpType::MUL, true>;
template struct BinaryOperation<BinaryOpType::LOGICAL_AND, false>;
template struct BinaryOperation<BinaryOpType::LOGICAL_OR, false>;
template struct BinaryOperation<BinaryOpType::LDEXP, false>;
template struct BinaryOperation<BinaryOpType::LOGADDEXP, false>;
template struct BinaryOperation<BinaryOpType::LOGADDEXP2, false>;
template struct BinaryOperation<BinaryOpType::SQUARED_DIFFERENCE, false>;
template struct BinaryOperationOverload<BinaryOpType::DIV_FAST, false>;
template struct BinaryOperation<BinaryOpType::DIV_FAST, false>;
template struct BinaryOperation<BinaryOpType::BIAS_GELU, false>;

template struct RelationalBinary<BinaryOpType::EQ>;
Expand Down
64 changes: 8 additions & 56 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
#include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
#include "device/binary_device_operation.hpp"
#include "ttnn/operations/eltwise/complex/complex.hpp"

namespace ttnn {

Expand Down Expand Up @@ -62,53 +61,6 @@ struct BinaryOperation {
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);
};

// Overloaded binary elementwise operation: like BinaryOperation, but with
// additional invoke() overloads for a float scalar right-hand side and for
// ComplexTensor operands. The op type and in-place behavior are fixed at
// compile time by the template parameters.
template <BinaryOpType binary_op_type, bool in_place>
struct BinaryOperationOverload {
// Tensor-tensor overload on an explicit command queue.
static Tensor invoke(
uint8_t queue_id,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<const DataType> &output_dtype = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);

// Tensor-tensor overload on the default queue.
static Tensor invoke(
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<const DataType> &output_dtype = std::nullopt,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);

// TODO: this case should use BinaryWithScalarProgramConfig and there should be a custom kernel to run this
// Currently, this is exactly how tt::tt_metal::add_unary works
// Tensor-scalar overload on the default queue.
static Tensor invoke(
const ttnn::Tensor &input_tensor_a,
const float scalar,
const std::optional<const DataType> &dtype = std::nullopt,
const std::optional<ttnn::MemoryConfig> &memory_config = std::nullopt,
const std::optional<Tensor> &optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);

// Tensor-scalar overload on an explicit command queue.
static Tensor invoke(
uint8_t queue_id,
const ttnn::Tensor &input_tensor_a,
const float scalar,
const std::optional<const DataType> &dtype = std::nullopt,
const std::optional<ttnn::MemoryConfig> &memory_config = std::nullopt,
const std::optional<Tensor> &optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);

// Complex-tensor overload (real/imaginary pair arithmetic).
static ComplexTensor invoke(
const ComplexTensor &input_tensor_a_arg,
const ComplexTensor &input_tensor_b_arg,
const MemoryConfig &memory_config);
};

template <BinaryOpType binary_op_type>
struct RelationalBinary {
Expand Down Expand Up @@ -181,21 +133,21 @@ struct InplaceLogicalBinary {
} // binary
} // operations

constexpr auto add = ttnn::register_operation<
constexpr auto add = ttnn::register_operation_with_auto_launch_op<
"ttnn::add",
operations::binary::BinaryOperationOverload<operations::binary::BinaryOpType::ADD, false>>();
operations::binary::BinaryOperation<operations::binary::BinaryOpType::ADD, false>>();
constexpr auto add_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::add_",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::ADD, true>>();
constexpr auto subtract = ttnn::register_operation<
constexpr auto subtract = ttnn::register_operation_with_auto_launch_op<
"ttnn::subtract",
operations::binary::BinaryOperationOverload<operations::binary::BinaryOpType::SUB, false>>();
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SUB, false>>();
constexpr auto subtract_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::subtract_",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SUB, true>>();
constexpr auto multiply = ttnn::register_operation<
constexpr auto multiply = ttnn::register_operation_with_auto_launch_op<
"ttnn::multiply",
operations::binary::BinaryOperationOverload<operations::binary::BinaryOpType::MUL, false>>();
operations::binary::BinaryOperation<operations::binary::BinaryOpType::MUL, false>>();
constexpr auto multiply_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::multiply_",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::MUL, true>>();
Expand Down Expand Up @@ -235,9 +187,9 @@ constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op<
"ttnn::squared_difference",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SQUARED_DIFFERENCE, false>>();
constexpr auto divide = ttnn::register_operation<
constexpr auto divide = ttnn::register_operation_with_auto_launch_op<
"ttnn::divide",
operations::binary::BinaryOperationOverload<operations::binary::BinaryOpType::DIV_FAST, false>>();
operations::binary::BinaryOperation<operations::binary::BinaryOpType::DIV_FAST, false>>();
constexpr auto gt_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::gt_",
operations::binary::InplaceRelationalBinary<operations::binary::BinaryOpType::GT>>();
Expand Down
Loading

0 comments on commit 93fbac3

Please sign in to comment.