diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp index 1413ba24a5ed..5bcc1ec44861 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp @@ -16,16 +16,15 @@ Tensor BatchNorm::invoke( std::optional running_var, const bool training, const float eps, + const float momentum, std::optional weight, std::optional bias, std::optional output, const std::optional& memory_config) { - // TODO: Implementation for training mode is in progress - TT_FATAL((!training), "Support currently provided for inference mode only"); TT_FATAL( - (running_mean.has_value() && running_var.has_value()), + (running_mean.has_value() && running_var.has_value() && (!training)), "running_mean and running_var must be defined in evaluation mode"); return ttnn::prim::batch_norm( - input, running_mean.value(), running_var.value(), eps, weight, bias, output, memory_config); + input, running_mean.value(), running_var.value(), eps, momentum, training, weight, bias, output, memory_config); } } // namespace ttnn::operations::normalization diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp index ae76785ef71c..5e17388b67d8 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.hpp @@ -15,6 +15,7 @@ struct BatchNorm { std::optional running_var = std::nullopt, const bool training = false, const float eps = 1e-05, + const float momentum = 0.1, std::optional weight = std::nullopt, std::optional bias = std::nullopt, std::optional output = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp index 5428b1aa4f35..c28165ef8c4e 100644 
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp @@ -23,6 +23,7 @@ void bind_batch_norm_operation(pybind11::module& module) { Keyword args: eps (float, optional): Epsilon value. Defaults to `1e-05`. + momentum (float, optional): Momentum value. Defaults to `0.1`. running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`. @@ -44,6 +45,7 @@ void bind_batch_norm_operation(pybind11::module& module) { py::arg("running_var") = std::nullopt, py::arg("training") = false, py::arg("eps") = 1e-05, + py::arg("momentum") = 0.1, py::arg("weight") = std::nullopt, py::arg("bias") = std::nullopt, py::arg("output") = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp index 3756cf561643..87caa2213397 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp @@ -123,11 +123,13 @@ std::tuple weight, std::optional bias, std::optional output, const std::optional& memory_config) { - operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config())}; + operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())}; tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)}; return {operation_attributes, tensor_args}; } diff --git 
a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp index 454fac5783dc..985634f6dfdb 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp @@ -12,6 +12,8 @@ namespace ttnn::operations::normalization { struct BatchNormOperation { struct operation_attributes_t { const float eps; + const float momentum; + const bool training; const MemoryConfig memory_config; DataType input_dtype; @@ -66,6 +68,8 @@ struct BatchNormOperation { const Tensor& batch_mean, const Tensor& batch_var, const float eps, + const float momentum, + const bool training, std::optional weight, std::optional bias, std::optional output, diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index 5ff56fe761d6..9391a5c99447 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -31,6 +31,7 @@ void set_or_update_runtime_arguments( F handle_args) { const auto& [a, b, d, e, f, _] = tensor_args; const auto eps = operation_attributes.eps; + const auto momentum = operation_attributes.momentum; const bool weight_has_value = e.has_value(); const bool bias_has_value = f.has_value(); @@ -63,17 +64,21 @@ void set_or_update_runtime_arguments( } else if (core_group_2.contains(core)) { num_tiles_per_core = num_tiles_per_core_group_2; } else { - handle_args(program, reader_kernel_id, core, std::array{0}); + handle_args(program, reader_kernel_id, core, std::array{0}); handle_args(program, writer_kernel_id, core, std::array{0}); handle_args(program, compute_kernel_id, core, 
std::array{0}); continue; } uint32_t cHtWt = cHt * cWt; - class bfloat16 bfloat_scalar(eps); - uint32_t packed_scalar = pack_two_bfloat16_into_uint32({bfloat_scalar, bfloat_scalar}); + class bfloat16 bfloat_scalar_eps(eps); + uint32_t packed_scalar_eps = pack_two_bfloat16_into_uint32({bfloat_scalar_eps, bfloat_scalar_eps}); + class bfloat16 bfloat_scalar_momentum(momentum); + uint32_t packed_scalar_momentum = + pack_two_bfloat16_into_uint32({bfloat_scalar_momentum, bfloat_scalar_momentum}); std::array reader_runtime_args = { - packed_scalar, + packed_scalar_eps, + packed_scalar_momentum, a.buffer()->address(), start_tile_id, num_tiles_per_core, @@ -187,6 +192,13 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch tt::CBIndex::c_16, program, all_device_cores, e_single_tile_size, b_num_tiles_per_cb, e_data_format); // weight auto [f_cb, f_cb_handle] = create_cb( tt::CBIndex::c_18, program, all_device_cores, f_single_tile_size, b_num_tiles_per_cb, f_data_format); // bias + auto [momentum_cb, momentum_cb_handle] = create_cb( + tt::CBIndex::c_24, + program, + all_device_cores, + d_single_tile_size, + b_num_tiles_per_cb, + d_data_format); // momentum // Temporary buffers to store intermediate results auto [den_cb, den_cb_handle] = create_cb( @@ -232,13 +244,16 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch e_is_dram, f_is_dram, static_cast(weight_has_value), - static_cast(bias_has_value)})); + static_cast(bias_has_value), + static_cast(operation_attributes.training)})); // COMPUTE KERNEL bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 || c_data_format == tt::DataFormat::Float32; std::vector compute_kernel_args = { - static_cast(weight_has_value), static_cast(bias_has_value)}; + static_cast(weight_has_value), + static_cast(bias_has_value), + static_cast(operation_attributes.training)}; auto compute_kernel_id = tt_metal::CreateKernel( program, 
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp", diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp index addf249190ca..b14d3ec5bfcf 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp @@ -39,6 +39,7 @@ void MAIN { uint32_t tile_start = get_arg_val(2); constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1; constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1; + constexpr uint32_t is_training_mode = get_compile_time_arg_val(2) == 1; if (num_tiles == 0) { return; @@ -116,6 +117,10 @@ void MAIN { cb_pop_front(cb_den, 1); cb_push_back(cb_affine_or_out, onetile); + if constexpr (is_training_mode) { + // update running stats here + } + if constexpr (weight_has_value) { // result = result * weight cb_reserve_back(cb_scaled_output, onetile); cb_wait_front(cb_affine_or_out, 1); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp index 6ac8233dd24f..22a019963129 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp @@ -10,14 +10,15 @@ void kernel_main() { const auto eps = get_arg_val(0); - uint32_t src_addr = get_arg_val(1); // input tensor - uint32_t start_tile_id = get_arg_val(2); - uint32_t num_tiles = get_arg_val(3); - uint32_t HtWt = get_arg_val(4); - uint32_t n_stride = get_arg_val(5); - uint32_t c_stride = get_arg_val(6); - uint32_t N = get_arg_val(7); - uint32_t C = 
get_arg_val(8); + const auto momentum = get_arg_val(1); + uint32_t src_addr = get_arg_val(2); // input tensor + uint32_t start_tile_id = get_arg_val(3); + uint32_t num_tiles = get_arg_val(4); + uint32_t HtWt = get_arg_val(5); + uint32_t n_stride = get_arg_val(6); + uint32_t c_stride = get_arg_val(7); + uint32_t N = get_arg_val(8); + uint32_t C = get_arg_val(9); constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; @@ -41,6 +42,12 @@ void kernel_main() { fill_with_val_bfloat16(cb_id_eps, eps); cb_push_back(cb_id_eps, onetile); + constexpr auto cb_id_momentum = tt::CBIndex::c_24; + + cb_reserve_back(cb_id_momentum, onetile); + fill_with_val_bfloat16(cb_id_momentum, momentum); + cb_push_back(cb_id_momentum, onetile); + // Input tile offset uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t; diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp index 72bfbaeef7f7..7fdc32b5339d 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp @@ -70,6 +70,7 @@ void kernel_main() { constexpr bool weight_has_value = get_compile_time_arg_val(5) == 1; constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1; + constexpr bool is_training_mode = get_compile_time_arg_val(7) == 1; uint32_t tiles_per_batch = HtWt * C; uint32_t start_n = start_tile_id / tiles_per_batch; @@ -118,6 +119,10 @@ void kernel_main() { cb_push_back(cb_id_bias, onetile); } + // read the running stats values so they can be updated + if constexpr (is_training_mode) { + } + for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) { // write a tile to dst cb_wait_front(cb_id_dst, onetile);