diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py index ff85ef3a56a4..93213b090ef9 100644 --- a/tests/ttnn/unit_tests/operations/test_batch_norm.py +++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py @@ -40,7 +40,7 @@ @pytest.mark.parametrize("weight", [True, False]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05]) -@pytest.mark.parametrize("momentum", [0.1, 0.0, 1.0, 2.3]) +@pytest.mark.parametrize("momentum", [0.0, 0.1, 0.5]) def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device): in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True) mean_data, mean_tensor = ( diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 22b6f4375adf..8fa807f529c4 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -304,6 +304,8 @@ set(TTNN_OP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/groupnorm_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/normalization/groupnorm/device/groupnorm_op.cpp diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp index 111c7aa15dcf..a64a9eef7d82 100644 
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp @@ -7,6 +7,7 @@ #include "device/batch_norm_device_operation.hpp" #include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp" #include "ttnn/operations/eltwise/unary/device/unary_composite_op.hpp" +#include "device/running_statistics_device_operation.hpp" using namespace tt::tt_metal; @@ -42,6 +43,8 @@ Tensor BatchNorm::invoke( Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config); Tensor batch_var = ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config); + Tensor stats = + ttnn::prim::running_statistics(batch_mean, batch_var, momentum, running_mean, running_var, memory_config); return ttnn::prim::batch_norm(input, batch_mean, batch_var, eps, weight, bias, output, memory_config); } TT_FATAL( diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp new file mode 100644 index 000000000000..781edcfdfa52 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "dprint.h" +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" + +namespace NAMESPACE { +void MAIN { + uint32_t num_tiles = get_arg_val(0); + constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(0) == 1; + constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(1) == 1; + + constexpr auto cb_batch_mean = tt::CBIndex::c_0; // batch mean + constexpr auto cb_batch_var = tt::CBIndex::c_1; // batch var + constexpr auto cb_out0 = tt::CBIndex::c_2; + constexpr auto cb_old_running_mean = tt::CBIndex::c_3; // old running mean tensor + constexpr auto cb_old_running_var = tt::CBIndex::c_4; // old running var tensor + constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor + constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor + constexpr auto cb_momentum = tt::CBIndex::c_5; // momentum + constexpr auto cb_one = tt::CBIndex::c_6; // stores 1 + constexpr auto cb_tmp1 = tt::CBIndex::c_21; // tmp 1 + constexpr auto cb_tmp2 = tt::CBIndex::c_22; // tmp 2 + constexpr auto cb_tmp3 = tt::CBIndex::c_23; // tmp 3 + + binary_op_init_common(cb_batch_mean, cb_batch_var, cb_out0); + constexpr uint32_t onetile = 1; + + for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { + tile_regs_acquire(); + // updated_running_stat = (1 − momentum) × running_stat + momentum × batch_stat + cb_wait_front(cb_one, 1); + cb_wait_front(cb_momentum, 1); + + if constexpr (old_running_mean_has_value) { + sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum + mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat + mul_tiles_to_cb(cb_tmp1, cb_old_running_mean, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats + add_tiles_to_cb(cb_tmp2, cb_tmp3, 
cb_updated_running_mean, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3 + } + if constexpr (old_running_var_has_value) { + sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, 0, 0, 0, 0); // 1 - momentum + mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, 0, 0, 1); // momentum * batch stat + mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, 0, 1, 1); // cb_tmp1 * running stats + add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); // cb_tmp2 * cb_tmp3 + } + cb_pop_front(cb_one, 1); + cb_pop_front(cb_momentum, 1); + tile_regs_commit(); + tile_regs_wait(); + pack_tile(0, cb_out0); + tile_regs_release(); + cb_push_back(cb_out0, 1); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp new file mode 100644 index 000000000000..9c80f01f2f76 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+// NOTE(review): the bare `#include` below and the untemplated get_arg_val /
+// InterleavedAddrGenFast uses look like they lost their <...> arguments in this
+// copy of the patch - confirm against the repository.
+#include
+
+#include "dataflow_api.h"
+#include "debug/dprint.h"
+#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"
+#include "cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp"
+
+// Reader kernel for the running-statistics op.
+//
+// Publishes two scalar tiles once (the constant 1 in CB c_6 and momentum in
+// CB c_5), then streams this core's share of source tiles into CB c_0, one tile
+// at a time. The host passes the batch_mean buffer address as the source.
+//
+// Runtime args: [0] momentum (host packs it as two bfloat16 values in one word),
+//               [1] source address, [2] first output tile id, [3] tile count,
+//               [4] tiles per (n, c) slice, [5] n stride, [6] c stride, [7] N, [8] C.
+// Compile-time args: [0] == 1 when the source buffer is in DRAM.
+void kernel_main() {
+    const auto momentum = get_arg_val(0);
+    uint32_t src_addr = get_arg_val(1);  // batch_mean buffer address (see host runtime args)
+    uint32_t start_tile_id = get_arg_val(2);
+    uint32_t num_tiles = get_arg_val(3);
+    uint32_t HtWt = get_arg_val(4);
+    uint32_t n_stride = get_arg_val(5);
+    uint32_t c_stride = get_arg_val(6);
+    uint32_t N = get_arg_val(7);
+    uint32_t C = get_arg_val(8);
+
+    // Not referenced again in this copy; presumably the template argument of the
+    // address generator below - TODO confirm.
+    constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
+
+    constexpr auto cb_id_src = tt::CBIndex::c_0;
+    constexpr auto cb_id_momentum = tt::CBIndex::c_5;
+    constexpr auto cb_id_one = tt::CBIndex::c_6;
+    constexpr uint32_t onetile = 1;
+
+    const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
+    const DataFormat src_data_format = get_dataformat(cb_id_src);
+    const InterleavedAddrGenFast src = {
+        .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};
+
+    // Decompose the linear output tile id into (n, c, t) coordinates.
+    uint32_t tiles_per_batch = HtWt * C;
+    uint32_t start_n = start_tile_id / tiles_per_batch;
+    uint32_t start_remaining = start_tile_id % tiles_per_batch;
+    uint32_t start_c = start_remaining / HtWt;
+    uint32_t start_t = start_remaining % HtWt;
+
+    // this is the INPUT tile offset
+    // (the host sets n_stride / c_stride to 0 when the source has size 1 along
+    // that dimension, so the same source tiles are re-read in that case)
+    uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t;
+
+    // Corrections applied when the t loop / c loop finish, so tile_offset lands
+    // on the first tile of the next channel / batch.
+    uint32_t next_channel_shift = c_stride - HtWt;
+    uint32_t next_batch_shift = n_stride - c_stride * C;
+
+    // Reinterpret 1.0f as raw bits for the fill helper.
+    union {
+        float f;
+        uint32_t u;
+    } scalar;
+    scalar.f = 1.0f;
+    // fill_cb_with_value comes from fill_tile_utils.hpp; presumably it reserves,
+    // fills and pushes one tile - TODO confirm.
+    fill_cb_with_value(cb_id_one, scalar.u);
+
+    // Momentum tile is produced exactly once, before the tile loop.
+    cb_reserve_back(cb_id_momentum, onetile);
+    fill_with_val_bfloat16(cb_id_momentum, momentum);
+    cb_push_back(cb_id_momentum, onetile);
+
+    uint32_t num_tiles_read = 0;
+    // start_c / start_t only apply to the first pass; they reset to 0 afterwards.
+    for (uint32_t n = start_n; n < N && num_tiles_read < num_tiles; ++n, start_c = 0) {
+        for (uint32_t c = start_c; c < C && num_tiles_read < num_tiles; ++c, start_t = 0) {
+            for (uint32_t t = start_t; t < HtWt && num_tiles_read < num_tiles; ++t, ++num_tiles_read, ++tile_offset) {
+                cb_reserve_back(cb_id_src, onetile);
+                uint32_t l1_write_addr_src = get_write_ptr(cb_id_src);
+                noc_async_read_tile(tile_offset, src, l1_write_addr_src);
+                // Wait for the read to land before handing the tile to compute.
+                noc_async_read_barrier();
+                cb_push_back(cb_id_src, onetile);
+            }
+            tile_offset += next_channel_shift;
+        }
+        tile_offset += next_batch_shift;
+    }
+}
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp
new file mode 100644
index 000000000000..61d0f8ea04c6
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp
@@ -0,0 +1,136 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+// NOTE(review): the bare `#include` below and the untemplated get_arg_val /
+// InterleavedAddrGenFast uses look like they lost their <...> arguments in this
+// copy of the patch - confirm against the repository.
+#include
+
+#include "dataflow_api.h"
+#include "cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp"
+
+// Writer kernel for the running-statistics op.
+//
+// Per tile it (1) reads a batch_var tile into CB c_1, (2) when a running mean /
+// running var is present, reads the old tile into CB c_3 / c_4 and then writes
+// the updated tile from CB c_27 / c_28 back to the same tile index of that same
+// tensor, and (3) writes the output tile from CB c_2 to the destination buffer.
+//
+// Runtime args: [0] batch_var address, [1] old running mean address,
+//               [2] old running var address, [3] output address,
+//               [4] first output tile id, [5] tile count, [6] tiles per (n, c)
+//               slice, [7] n stride, [8] c stride, [9] N, [10] C.
+// Compile-time args: [0..3] DRAM flags for src / dst / running mean / running var,
+//                    [4] running mean present, [5] running var present.
+void kernel_main() {
+    uint32_t src_addr = get_arg_val(0);               // batch_var
+    uint32_t old_running_mean_addr = get_arg_val(1);  // old running_mean
+    uint32_t old_running_var_addr = get_arg_val(2);   // old running_var
+    uint32_t dst_addr = get_arg_val(3);               // output
+    uint32_t start_tile_id = get_arg_val(4);
+    uint32_t num_tiles = get_arg_val(5);
+    uint32_t HtWt = get_arg_val(6);
+    uint32_t n_stride = get_arg_val(7);
+    uint32_t c_stride = get_arg_val(8);
+    uint32_t N = get_arg_val(9);
+    uint32_t C = get_arg_val(10);
+
+    constexpr uint32_t onetile = 1;
+
+    // batch_var source
+    constexpr auto cb_id_src = tt::CBIndex::c_1;
+    constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
+    const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
+    const DataFormat src_data_format = get_dataformat(cb_id_src);
+
+    const InterleavedAddrGenFast src = {
+        .bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};
+
+    // output destination
+    constexpr auto cb_id_dst = tt::CBIndex::c_2;
+    constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
+    const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
+    const DataFormat dst_data_format = get_dataformat(cb_id_dst);
+
+    const InterleavedAddrGenFast dst = {
+        .bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format};
+
+    // old running mean
+    constexpr auto cb_id_old_running_mean = tt::CBIndex::c_3;
+    constexpr bool old_running_mean_is_dram = get_compile_time_arg_val(2) == 1;
+    const uint32_t old_running_mean_tile_bytes = get_tile_size(cb_id_old_running_mean);
+    const DataFormat old_running_mean_data_format = get_dataformat(cb_id_old_running_mean);
+
+    const InterleavedAddrGenFast old_running_mean = {
+        .bank_base_address = old_running_mean_addr,
+        .page_size = old_running_mean_tile_bytes,
+        .data_format = old_running_mean_data_format};
+
+    // old running var
+    constexpr auto cb_id_old_running_var = tt::CBIndex::c_4;
+    constexpr bool old_running_var_is_dram = get_compile_time_arg_val(3) == 1;
+    const uint32_t old_running_var_tile_bytes = get_tile_size(cb_id_old_running_var);
+    const DataFormat old_running_var_data_format = get_dataformat(cb_id_old_running_var);
+
+    const InterleavedAddrGenFast old_running_var = {
+        .bank_base_address = old_running_var_addr,
+        .page_size = old_running_var_tile_bytes,
+        .data_format = old_running_var_data_format};
+
+    constexpr bool old_running_mean_has_value = get_compile_time_arg_val(4) == 1;
+    constexpr bool old_running_var_has_value = get_compile_time_arg_val(5) == 1;
+    // CBs the compute kernel fills with the refreshed statistics.
+    constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
+    constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;
+
+    // Decompose the linear output tile id into (n, c, t) coordinates.
+    uint32_t tiles_per_batch = HtWt * C;
+    uint32_t start_n = start_tile_id / tiles_per_batch;
+    uint32_t start_remaining = start_tile_id % tiles_per_batch;
+    uint32_t start_c = start_remaining / HtWt;
+    uint32_t start_t = start_remaining % HtWt;
+
+    // this is the INPUT tile offset
+    // (used for batch_var and, unchanged, for both running-statistic tensors)
+    uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t;
+    uint32_t next_channel_shift = c_stride - HtWt;
+    uint32_t next_batch_shift = n_stride - c_stride * C;
+
+    uint32_t num_tiles_written = 0;
+    // start_c / start_t only apply to the first pass; they reset to 0 afterwards.
+    for (uint32_t n = start_n; n < N && num_tiles_written < num_tiles; ++n, start_c = 0) {
+        for (uint32_t c = start_c; c < C && num_tiles_written < num_tiles; ++c, start_t = 0) {
+            for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
+                // read a tile from src
+                cb_reserve_back(cb_id_src, onetile);
+                uint32_t l1_write_addr = get_write_ptr(cb_id_src);
+                noc_async_read_tile(tile_offset, src, l1_write_addr);
+                noc_async_read_barrier();
+                cb_push_back(cb_id_src, onetile);
+
+                if constexpr (old_running_mean_has_value) {
+                    // read data
+                    cb_reserve_back(cb_id_old_running_mean, onetile);
+                    uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean);
+                    noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr);
+                    noc_async_read_barrier();
+                    // Helper from fill_tile_utils.hpp; going by its name it
+                    // replicates element 0 across the tile - TODO confirm.
+                    fill_tile_with_first_element_bfloat16(cb_id_old_running_mean);
+                    cb_push_back(cb_id_old_running_mean, onetile);
+
+                    // write data
+                    // The updated tile overwrites the tile that was just read
+                    // (same address generator, same tile_offset).
+                    cb_wait_front(cb_id_updated_running_mean, onetile);
+                    uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean);
+                    noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr);
+                    noc_async_write_barrier();
+                    cb_pop_front(cb_id_updated_running_mean, onetile);
+                }
+
+                if constexpr (old_running_var_has_value) {
+                    // read data
+                    cb_reserve_back(cb_id_old_running_var, onetile);
+                    uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var);
+                    noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr);
+                    noc_async_read_barrier();
+                    fill_tile_with_first_element_bfloat16(cb_id_old_running_var);
+                    cb_push_back(cb_id_old_running_var, onetile);
+
+                    // write data
+                    // Same in-place overwrite as for the running mean above.
+                    cb_wait_front(cb_id_updated_running_var, onetile);
+                    uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var);
+                    noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr);
+                    noc_async_write_barrier();
+                    cb_pop_front(cb_id_updated_running_var, onetile);
+                }
+                // Advance only after both running-statistic tensors used this offset.
+                ++tile_offset;
+
+                // write a tile to dst, since the dst shape is full, the tile offset simply grows linearly
+                cb_wait_front(cb_id_dst, onetile);
+                uint32_t l1_read_addr = get_read_ptr(cb_id_dst);
+                noc_async_write_tile(start_tile_id + num_tiles_written, dst, l1_read_addr);
+                noc_async_write_barrier();
+                cb_pop_front(cb_id_dst, onetile);
+            }
+            tile_offset += next_channel_shift;
+        }
+        tile_offset += next_batch_shift;
+    }
+}
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp
new file mode 100644
index 000000000000..d64f303e5980
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.cpp
@@ -0,0 +1,117 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "running_statistics_device_operation.hpp" + +#include "ttnn/operations/moreh/moreh_helper_functions.hpp" +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::normalization { +void RunningStatistics::validate_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& [batch_mean, batch_var, running_mean, running_var] = tensor_args; + + check_tensor(batch_mean, "running_statistics", "batch_mean"); + check_tensor(batch_var, "running_statistics", "batch_var"); + check_tensor(running_mean, "running_statistics", "running_mean"); + check_tensor(running_var, "running_statistics", "running_var"); + + // mean (1, C, 1, 1) + auto C = batch_mean.get_logical_shape()[1]; + // var (1, C, 1, 1) + TT_FATAL(batch_var.get_logical_shape()[1] == C, "batch_var_shape[1] must be the same as input's channel size."); + + // running_mean (1, C, 1, 1) + if (running_mean.has_value()) { + TT_FATAL( + running_mean.value().get_logical_shape()[1] == C, + "running_mean_shape[1] must be the same as input's channel size."); + TT_FATAL( + running_mean.value().get_logical_shape()[1] == C, + "running_mean_shape[1] must be the same as input's channel size."); + } + + // running_var (1, C, 1, 1) + if (running_var.has_value()) { + TT_FATAL( + running_var.value().get_logical_shape()[1] == C, + "running_var_shape[1] must be the same as input's channel size."); + TT_FATAL( + running_var.value().get_logical_shape()[1] == C, + "running_var_shape[1] must be the same as input's channel size."); + } +} + +RunningStatistics::program_factory_t RunningStatistics::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return RunningStatisticsFactory(); +} + +void RunningStatistics::validate_on_program_cache_miss( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& [batch_mean, 
batch_var, running_mean, running_var] = tensor_args; + + TT_FATAL(batch_mean.get_layout() == Layout::TILE, "batch_mean tensor must be tilized"); + TT_FATAL( + batch_mean.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, + "batch_mean tensor must be interleaved"); + + TT_FATAL(batch_var.get_layout() == Layout::TILE, "batch_var tensor must be tilized"); + TT_FATAL( + batch_var.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, + "batch_var tensor must be interleaved"); + + if (running_mean.has_value()) { + TT_FATAL(running_mean.value().get_layout() == Layout::TILE, "running_mean tensor must be tilized"); + TT_FATAL( + running_mean.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, + "running_mean tensor must be interleaved"); + } + + if (running_var.has_value()) { + TT_FATAL(running_var.value().get_layout() == Layout::TILE, "running_var tensor must be tilized"); + TT_FATAL( + running_var.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, + "running_var tensor must be interleaved"); + } + + validate_tensors(operation_attributes, tensor_args); +}; + +void RunningStatistics::validate_on_program_cache_hit( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_tensors(operation_attributes, tensor_args); +}; + +DataType RunningStatistics::operation_attributes_t::get_dtype() const { + return this->dtype.value_or(this->input_dtype); +} + +RunningStatistics::spec_return_value_t RunningStatistics::compute_output_specs( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + using namespace tt::constants; + const auto output_shape = tensor_args.batch_mean.get_logical_shape(); + return TensorSpec( + output_shape, + TensorLayout(operation_attributes.get_dtype(), PageConfig(Layout::TILE), operation_attributes.memory_config)); +} + +RunningStatistics::tensor_return_value_t RunningStatistics::create_output_tensors( + const 
operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return create_device_tensor( + compute_output_specs(operation_attributes, tensor_args), tensor_args.batch_mean.device()); +} + +std::tuple RunningStatistics::invoke( + const Tensor& batch_mean, + const Tensor& batch_var, + const float momentum, + std::optional running_mean, + std::optional running_var, + const std::optional& memory_config) { + operation_attributes_t operation_attributes{momentum, memory_config.value_or(batch_mean.memory_config())}; + tensor_args_t tensor_args{batch_mean, batch_var, std::move(running_mean), std::move(running_var)}; + return {operation_attributes, tensor_args}; +} +} // namespace ttnn::operations::normalization diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.hpp new file mode 100644 index 000000000000..6de62be91ae9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_device_operation.hpp @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ttnn/decorators.hpp"
+#include "ttnn/device_operation.hpp"
+#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
+
+namespace ttnn::operations::normalization {
+// Device operation that blends batch statistics into the running statistics:
+//     updated_running_stat = (1 - momentum) * running_stat + momentum * batch_stat
+//
+// NOTE(review): several template argument lists (std::optional, std::variant,
+// std::tuple, CachedProgram) appear to be missing in this copy of the patch -
+// confirm against the repository.
+struct RunningStatistics {
+    // Non-tensor parameters of the operation.
+    struct operation_attributes_t {
+        const float momentum;
+        const MemoryConfig memory_config;
+
+        DataType input_dtype;
+        std::optional dtype;
+        // Returns dtype when set, otherwise input_dtype.
+        DataType get_dtype() const;
+    };
+
+    // Tensor inputs. batch_mean / batch_var are held by reference; the running
+    // statistics are optional.
+    struct tensor_args_t {
+        const Tensor& batch_mean;
+        const Tensor& batch_var;
+        std::optional running_mean;
+        std::optional running_var;
+    };
+
+    using spec_return_value_t = TensorSpec;
+    using tensor_return_value_t = Tensor;
+
+    // Builds the program and refreshes its runtime arguments on reuse.
+    struct RunningStatisticsFactory {
+        // State kept with the cached program so runtime args can be rewritten.
+        struct shared_variables_t {
+            tt::tt_metal::KernelHandle reader_kernel_id;
+            tt::tt_metal::KernelHandle writer_kernel_id;
+            tt::tt_metal::KernelHandle compute_kernel_id;
+            CoreCoord compute_with_storage_grid_size;
+        };
+
+        using cached_program_t = ttnn::device_operation::CachedProgram;
+
+        static cached_program_t create(
+            const operation_attributes_t& operation_attributes,
+            const tensor_args_t& tensor_args,
+            tensor_return_value_t& output);
+
+        static void override_runtime_arguments(
+            cached_program_t& cached_program,
+            const operation_attributes_t& operation_attributes,
+            const tensor_args_t& tensor_args,
+            tensor_return_value_t& output);
+    };
+
+    using program_factory_t = std::variant;
+
+    static void validate_tensors(const operation_attributes_t&, const tensor_args_t&);
+    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
+    static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
+    static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
+    static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
+    static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);
+    // Converts the call arguments into (operation_attributes_t, tensor_args_t).
+    static std::tuple invoke(
+        const Tensor& batch_mean,
+        const Tensor& batch_var,
+        const float momentum,
+        std::optional running_mean,
+        std::optional running_var,
+        const std::optional& memory_config);
+};
+}  // namespace ttnn::operations::normalization
+
+namespace ttnn::prim {
+// Registers the operation under the name "ttnn::prim::running_statistics".
+constexpr auto running_statistics =
+    ttnn::register_operation<"ttnn::prim::running_statistics", ttnn::operations::normalization::RunningStatistics>();
+}
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp
new file mode 100644
index 000000000000..9b79757665dc
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp
@@ -0,0 +1,317 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "running_statistics_device_operation.hpp"
+
+// NOTE(review): the two bare `#include` lines below, and the untemplated
+// std::tuple / template / std::array / static_cast / std::vector uses further
+// down, look like they lost their <...> arguments in this copy - confirm.
+#include
+#include "ttnn/operations/cb_utils.hpp"
+#include
+
+namespace {
+namespace CMAKE_UNIQUE_NAMESPACE {
+
+using namespace ttnn::operations::normalization;
+
+// Returns (N, C, Ht, Wt): the two outer padded dims as-is and the two inner
+// padded dims divided by the tile height / width.
+std::tuple extract_shape_dims(const Tensor& x) {
+    const auto& shape = x.padded_shape();
+    const auto& tile = x.tensor_spec().tile();
+    return {shape[-4], shape[-3], shape[-2] / tile.get_height(), shape[-1] / tile.get_width()};
+}
+
+// Computes the per-core runtime arguments for the reader, writer and compute
+// kernels and hands each argument array to `handle_args`, which either sets
+// them (program creation) or overwrites them (cached program reuse).
+// Here a = batch_mean, b = batch_var, d = running_mean, e = running_var and
+// c = the output tensor.
+template
+void set_or_update_runtime_arguments(
+    Program& program,
+    KernelHandle reader_kernel_id,
+    KernelHandle writer_kernel_id,
+    KernelHandle compute_kernel_id,
+    CoreCoord compute_with_storage_grid_size,
+    const RunningStatistics::operation_attributes_t& operation_attributes,
+    const RunningStatistics::tensor_args_t& tensor_args,
+    RunningStatistics::tensor_return_value_t& c,
+    F handle_args) {
+    const auto& [a, b, d, e] = tensor_args;
+    const auto momentum = operation_attributes.momentum;
+
+    const bool running_mean_has_value = d.has_value();
+    const bool running_var_has_value = e.has_value();
+
+    // NOTE(review): ashape / bshape / cshape are not used below.
+    const auto ashape = a.padded_shape();
+    const auto bshape = b.padded_shape();
+    const auto cshape = c.padded_shape();
+
+    const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
+    const auto [bN, bC, bHt, bWt] = extract_shape_dims(b);
+    const auto [cN, cC, cHt, cWt] = extract_shape_dims(c);
+
+    uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw();
+
+    // Work is split over the output tiles across the whole compute grid.
+    constexpr bool row_major = true;
+    uint32_t num_cores_x = compute_with_storage_grid_size.x;
+    uint32_t num_cores_y = compute_with_storage_grid_size.y;
+    uint32_t num_cores_total = num_cores_x * num_cores_y;
+    auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});
+    auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
+        tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major);
+
+    auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);
+    for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) {
+        const auto& core = cores[i];
+
+        uint32_t num_tiles_per_core;
+        if (core_group_1.contains(core)) {
+            num_tiles_per_core = num_tiles_per_core_group_1;
+        } else if (core_group_2.contains(core)) {
+            num_tiles_per_core = num_tiles_per_core_group_2;
+        } else {
+            // Core has no work: every kernel gets a zero argument array.
+            // NOTE(review): the dataflow kernels read their tile count from a
+            // later argument index - confirm the array length covers it.
+            handle_args(program, reader_kernel_id, core, std::array{0});
+            handle_args(program, writer_kernel_id, core, std::array{0});
+            handle_args(program, compute_kernel_id, core, std::array{0});
+            continue;
+        }
+
+        uint32_t cHtWt = cHt * cWt;
+        // Momentum is sent to the reader as two bfloat16 copies in one word.
+        class bfloat16 bfloat_scalar_momentum(momentum);
+        uint32_t packed_scalar_momentum =
+            pack_two_bfloat16_into_uint32({bfloat_scalar_momentum, bfloat_scalar_momentum});
+        // Strides are multiplied by (dim > 1), so a size-1 batch / channel
+        // dimension yields stride 0.
+        std::array reader_runtime_args = {
+            packed_scalar_momentum,
+            a.buffer()->address(),
+            start_tile_id,
+            num_tiles_per_core,
+            cHtWt,
+            aHt * aWt * aC * (aN > 1),
+            aHt * aWt * (aC > 1),
+            cN,
+            cC,
+            cHt,
+            cWt};
+        handle_args(program, reader_kernel_id, core, reader_runtime_args);
+
+        // Address 0 is passed for a running statistic that was not supplied.
+        const auto running_mean_addr = running_mean_has_value ? d->buffer()->address() : 0;
+        const auto running_var_addr = running_var_has_value ? e->buffer()->address() : 0;
+        std::array writer_runtime_args = {
+            b.buffer()->address(),  // batch var
+            running_mean_addr,      // old running mean
+            running_var_addr,       // old running var
+            c.buffer()->address(),  // output
+            start_tile_id,
+            num_tiles_per_core,
+            cHtWt,
+            bHt * bWt * bC * (bN > 1),
+            bHt * bWt * (bC > 1),
+            cN,
+            cC,
+            cHt,
+            cWt};
+        handle_args(program, writer_kernel_id, core, writer_runtime_args);
+
+        auto counter = start_tile_id % cHtWt;
+        auto freq = cHtWt;
+
+        // NOTE(review): the compute kernel in this patch only reads argument 0.
+        std::array compute_runtime_args = {num_tiles_per_core, freq, counter};
+        handle_args(program, compute_kernel_id, core, compute_runtime_args);
+
+        start_tile_id += num_tiles_per_core;
+    }
+}
+
+}  // namespace CMAKE_UNIQUE_NAMESPACE
+}  // namespace
+
+namespace ttnn::operations::normalization {
+// Creates the program: circular buffers, the reader / writer / compute kernels
+// on every core of the compute grid, and the initial runtime arguments.
+RunningStatistics::RunningStatisticsFactory::cached_program_t RunningStatistics::RunningStatisticsFactory::create(
+    const operation_attributes_t& operation_attributes,
+    const tensor_args_t& tensor_args,
+    tensor_return_value_t& output) {
+    using namespace tt;
+    using namespace tt::tt_metal;
+
+    // a = batch_mean, b = batch_var, d = running_mean, e = running_var
+    const auto& [a, b, d, e] = tensor_args;
+
+    auto program = CreateProgram();
+
+    auto* device = a.device();
+
+    const bool running_mean_has_value = d.has_value();
+    const bool running_var_has_value = e.has_value();
+
+    // Absent running statistics fall back to Float16_b for their CB format.
+    auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
+    auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
+    auto c_data_format = datatype_to_dataformat_converter(output.get_dtype());
+    auto d_data_format =
+        running_mean_has_value ? datatype_to_dataformat_converter(d->get_dtype()) : DataFormat::Float16_b;
+    auto e_data_format =
+        running_var_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
+
+    uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
+    uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
+    uint32_t c_single_tile_size = tt_metal::detail::TileSize(c_data_format);
+    uint32_t d_single_tile_size = tt_metal::detail::TileSize(d_data_format);
+    uint32_t e_single_tile_size = tt_metal::detail::TileSize(e_data_format);
+
+    uint32_t num_output_tiles = output.volume() / output.tensor_spec().tile().get_tile_hw();
+
+    // we parallelize the computation across the output tiles
+    constexpr bool row_major = true;
+    auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
+    uint32_t num_cores_x = compute_with_storage_grid_size.x;
+    uint32_t num_cores_y = compute_with_storage_grid_size.y;
+    auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});
+
+    // Number of tiles to store per input CB (double buffer)
+    constexpr uint32_t num_tiles_per_cb = 2;
+    uint32_t b_num_tiles_per_cb = num_tiles_per_cb;
+
+    // Input buffers
+    auto [a_cb, a_cb_handle] = create_cb(
+        tt::CBIndex::c_0,
+        program,
+        all_device_cores,
+        a_single_tile_size,
+        num_tiles_per_cb,
+        a_data_format);  // batch_mean
+    auto [b_cb, b_cb_handle] = create_cb(
+        tt::CBIndex::c_1,
+        program,
+        all_device_cores,
+        b_single_tile_size,
+        b_num_tiles_per_cb,
+        b_data_format);  // batch_var
+    auto [c_cb, c_cb_handle] = create_cb(
+        tt::CBIndex::c_2, program, all_device_cores, c_single_tile_size, num_tiles_per_cb, c_data_format);  // output
+    auto [d_cb, d_cb_handle] = create_cb(
+        tt::CBIndex::c_3,
+        program,
+        all_device_cores,
+        d_single_tile_size,
+        b_num_tiles_per_cb,
+        d_data_format);  // old running mean
+    auto [e_cb, e_cb_handle] = create_cb(
+        tt::CBIndex::c_4,
+        program,
+        all_device_cores,
+        e_single_tile_size,
+        b_num_tiles_per_cb,
+        e_data_format);  // old running var
+    // NOTE(review): the momentum and "1" CBs take batch_var's tile size and
+    // format, while the reader fills momentum with a bfloat16 helper - confirm
+    // this holds for non-bfloat16 inputs.
+    auto [f_cb, f_cb_handle] = create_cb(
+        tt::CBIndex::c_5,
+        program,
+        all_device_cores,
+        b_single_tile_size,
+        b_num_tiles_per_cb,
+        b_data_format);  // momentum
+    auto [one_cb, one_cb_handle] = create_cb(
+        tt::CBIndex::c_6,
+        program,
+        all_device_cores,
+        b_single_tile_size,
+        b_num_tiles_per_cb,
+        b_data_format);  // to store 1
+    auto [updated_m_cb, updated_m_cb_handle] = create_cb(
+        tt::CBIndex::c_27,
+        program,
+        all_device_cores,
+        d_single_tile_size,
+        b_num_tiles_per_cb,
+        d_data_format);  // updated running mean
+    auto [updated_v_cb, updated_v_cb_handle] = create_cb(
+        tt::CBIndex::c_28,
+        program,
+        all_device_cores,
+        e_single_tile_size,
+        b_num_tiles_per_cb,
+        e_data_format);  // updated running var
+
+    // Intermediate buffers required for updating the running stats
+
+    auto [tmp1_cb, tmp1_cb_handle] =
+        create_cb(tt::CBIndex::c_21, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);
+
+    auto [tmp2_cb, tmp2_cb_handle] =
+        create_cb(tt::CBIndex::c_22, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);
+
+    auto [tmp3_cb, tmp3_cb_handle] =
+        create_cb(tt::CBIndex::c_23, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);
+
+    // DRAM flags become compile-time args of the dataflow kernels; an absent
+    // running statistic counts as "not DRAM".
+    auto a_is_dram = static_cast(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+    auto b_is_dram = static_cast(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+    auto c_is_dram = static_cast(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
+    const auto d_is_dram = running_mean_has_value and d->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+    const auto e_is_dram = running_var_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+
+    // READER KERNEL
+    auto reader_kernel_id = tt_metal::CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_running_statistics.cpp",
+        all_device_cores,
+        tt_metal::ReaderDataMovementConfig({a_is_dram}));
+
+    // WRITER KERNEL
+    auto writer_kernel_id = tt_metal::CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_running_statistics.cpp",
+        all_device_cores,
+        tt_metal::WriterDataMovementConfig({
+            b_is_dram,
+            c_is_dram,
+            d_is_dram,
+            e_is_dram,
+            static_cast(running_mean_has_value),
+            static_cast(running_var_has_value),
+        }));
+
+    // COMPUTE KERNEL
+    // fp32 destination accumulation is enabled for 32-bit output formats.
+    bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
+                            c_data_format == tt::DataFormat::Float32;
+    std::vector compute_kernel_args = {
+        static_cast(running_mean_has_value), static_cast(running_var_has_value)};
+    auto compute_kernel_id = tt_metal::CreateKernel(
+        program,
+        "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/running_statistics_kernel.cpp",
+        all_device_cores,
+        tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_kernel_args});
+
+    // On creation the runtime arguments are set fresh.
+    auto set_runtime_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) {
+        tt_metal::SetRuntimeArgs(program, kernel_id, core, args);
+    };
+
+    CMAKE_UNIQUE_NAMESPACE::set_or_update_runtime_arguments(
+        program,
+        reader_kernel_id,
+        writer_kernel_id,
+        compute_kernel_id,
+        compute_with_storage_grid_size,
+        operation_attributes,
+        tensor_args,
+        output,
+        set_runtime_args);
+
+    return {
+        std::move(program), {reader_kernel_id, writer_kernel_id, compute_kernel_id, compute_with_storage_grid_size}};
+}
+
+// Reuses a cached program: recomputes the same per-core argument arrays and
+// copies them over the existing runtime arguments in place.
+void RunningStatistics::RunningStatisticsFactory::override_runtime_arguments(
+    cached_program_t& cached_program,
+    const operation_attributes_t& operation_attributes,
+    const tensor_args_t& tensor_args,
+    tensor_return_value_t& output) {
+    auto update_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) {
+        auto& all_args = GetRuntimeArgs(program, kernel_id);
+        auto& core_args = all_args.at(core.x).at(core.y);
+        std::copy(args.begin(), args.end(), core_args.data());
+    };
+
+    CMAKE_UNIQUE_NAMESPACE::set_or_update_runtime_arguments(
+        cached_program.program,
+        cached_program.shared_variables.reader_kernel_id,
+        cached_program.shared_variables.writer_kernel_id,
+        cached_program.shared_variables.compute_kernel_id,
+        cached_program.shared_variables.compute_with_storage_grid_size,
+        operation_attributes,
+        tensor_args,
+        output,
+        update_args);
+}
+
+}  // namespace ttnn::operations::normalization