diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py
index 5b2287201a24..8e8b7bd6afb0 100644
--- a/tests/ttnn/unit_tests/operations/test_batch_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -24,31 +24,50 @@
         torch.Size([3, 2, 64, 120]),
     ],
 )
-@pytest.mark.parametrize("training", [False])
+@pytest.mark.parametrize(
+    "training, check_mean, check_var",
+    [
+        # (True, True, True),
+        # (True, True, False),
+        # (True, False, True),
+        (True, False, False),
+        (False, False, False),  # xfail case
+        (False, True, False),  # xfail case
+        (False, False, True),  # xfail case
+        (False, True, True),
+    ],
+)
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-def test_batch_norm(input_shapes, training, weight, bias, eps, device):
+@pytest.mark.parametrize("momentum", [0.1, 0.0, 2.3])
+def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
     in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
     mean_data, mean_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
-    )
-    var_data, var_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
     )
+    var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
     weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
     bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

+    if (not training) and ((not check_mean) or (not check_var)):
+        pytest.xfail("running_mean and running_var must be defined in evaluation mode")
+
     tt_output_tensor_on_device = ttnn.batch_norm(
         input_tensor,
         running_mean=mean_tensor,
         running_var=var_tensor,
         training=training,
         eps=eps,
+        momentum=momentum,
         weight=weight_tensor,
         bias=bias_tensor,
     )
     tt_output = ttnn.to_torch(tt_output_tensor_on_device)
+
+    # tt_updated_mean = ttnn.to_torch(mean_tensor)
+    # tt_updated_var = ttnn.to_torch(var_tensor)
+
     # ttnn.set_printoptions(profile="full")
     # print("TT result : ", tt_output, tt_output.shape)
     # torch.set_printoptions(precision=5, sci_mode=False)
@@ -60,9 +79,14 @@ def test_batch_norm(input_shapes, training, weight, bias, eps, device):
         bias=bias_data,
         training=training,
         eps=eps,
+        momentum=momentum,
     )
     # print("Torch result : ", torch_result)
-    comp_pass = compare_results_batch_norm([tt_output], [torch_result])
+    comp_pass = compare_results_batch_norm([tt_output], [torch_result])  # Check BN result
+    # if training:
+    #     comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data])
+    #     comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data])
+    #     comp_pass = comp_pass and comp_pass_1 and comp_pass_2
     assert comp_pass
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
index 5bcc1ec44861..dba53e2c5ac6 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -5,11 +5,27 @@
 #include "batch_norm.hpp"

 #include "device/batch_norm_device_operation.hpp"
+#include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp"
+#include "ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp"

 using namespace tt::tt_metal;

 namespace ttnn::operations::normalization {
+
+inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
+    auto output_memory_config = memory_config.value_or(input_tensor.memory_config());
+    auto batch_mean = input_tensor;
+    ttnn::SmallVector<int64_t> dims = {0, 2, 3};
+    std::sort(dims.begin(), dims.end());
+    for (uint32_t i = dims.size() - 1; i > 0; i--) {
+        auto temp_output = ttnn::prim::moreh_mean(
+            batch_mean, dims[i], true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
+        batch_mean = temp_output;
+    }
+    return ttnn::prim::moreh_mean(
+        batch_mean, dims.front(), true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
+}
+
 Tensor BatchNorm::invoke(
     const Tensor& input,
     std::optional<Tensor> running_mean,
@@ -21,10 +37,40 @@ Tensor BatchNorm::invoke(
     std::optional<Tensor> bias,
     std::optional<Tensor> output,
     const std::optional<MemoryConfig>& memory_config) {
+    if (training) {
+        Tensor batch_mean = mean_NHW(input, memory_config);
+        Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
+        Tensor batch_var =
+            ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
+        return ttnn::prim::batch_norm(
+            input,
+            batch_mean,
+            batch_var,
+            eps,
+            momentum,
+            training,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            output,
+            memory_config);
+    }
     TT_FATAL(
-        (running_mean.has_value() && running_var.has_value() && (!training)),
+        (running_mean.has_value() && running_var.has_value()),
         "running_mean and running_var must be defined in evaluation mode");
     return ttnn::prim::batch_norm(
-        input, running_mean.value(), running_var.value(), eps, momentum, training, weight, bias, output, memory_config);
+        input,
+        running_mean.value(),
+        running_var.value(),
+        eps,
+        momentum,
+        training,
+        weight,
+        bias,
+        std::nullopt,
+        std::nullopt,
+        output,
+        memory_config);
 }

 }  // namespace ttnn::operations::normalization
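The training branch above derives the batch statistics from two `mean_NHW` reductions, using the identity Var[x] = E[x²] − (E[x])², i.e. the biased variance over the N, H, W dims. A quick PyTorch check of that identity (shapes here are illustrative, not taken from this change):

```python
import torch

x = torch.randn(2, 3, 64, 64)
mean = x.mean(dim=(0, 2, 3), keepdim=True)           # mean_NHW(x)
mean_sq = (x * x).mean(dim=(0, 2, 3), keepdim=True)  # mean_NHW(square(x))
var = mean_sq - mean * mean                          # E[x^2] - E[x]^2

ref = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
assert torch.allclose(var, ref, atol=1e-5)
```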
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
index c28165ef8c4e..92973c9da986 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
@@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
         module,
         ttnn::batch_norm,
         R"doc(
-            Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only.
+            Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved.


        Args:
@@ -24,8 +24,8 @@ void bind_batch_norm_operation(pybind11::module& module) {

        Keyword args:
            eps (float, optional): Epsilon value. Defaults to `1e-05`.
            momentum (float, optional): Momentum value. Defaults to `0.1`.
-            running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
-            running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
+            running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode. In training mode it is optional; when provided, the updated running mean is stored back into it in place. Defaults to `None`.
+            running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode. In training mode it is optional; when provided, the updated running variance is stored back into it in place. Defaults to `None`.
            weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
            bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
            training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
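A minimal usage sketch of the API documented above; the device handle, shapes, and values are illustrative assumptions, not taken from this change:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Inputs must be tilized and interleaved; shape is illustrative (N, C, H, W).
torch_input = torch.randn(2, 3, 64, 64, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Training mode: batch statistics are computed on device; running_mean and
# running_var are optional and, when passed, are updated in place via `momentum`.
output = ttnn.batch_norm(input_tensor, training=True, eps=1e-05, momentum=0.1)

ttnn.close_device(device)
```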
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
index 87caa2213397..9673493a5e52 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
@@ -10,7 +10,7 @@ namespace ttnn::operations::normalization {
 void BatchNormOperation::validate_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
+    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;

     check_tensor(input, "batch_norm", "input");
     check_tensor(batch_mean, "batch_norm", "batch_mean");
@@ -18,6 +18,8 @@ void BatchNormOperation::validate_tensors(
     check_tensor(weight, "batch_norm", "weight");
     check_tensor(bias, "batch_norm", "bias");
     check_tensor(output, "batch_norm", "output");
+    check_tensor(running_mean, "batch_norm", "running_mean");
+    check_tensor(running_var, "batch_norm", "running_var");

     // input (N, C, H, W)
     auto C = input.get_logical_shape()[1];
@@ -45,6 +47,26 @@ void BatchNormOperation::validate_tensors(
         TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
         TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
     }
+
+    // running_mean (1, C, 1, 1)
+    if (running_mean.has_value()) {
+        TT_FATAL(
+            running_mean.value().get_logical_shape()[1] == C,
+            "running_mean_shape[1] must be the same as input's channel size.");
+    }
+
+    // running_var (1, C, 1, 1)
+    if (running_var.has_value()) {
+        TT_FATAL(
+            running_var.value().get_logical_shape()[1] == C,
+            "running_var_shape[1] must be the same as input's channel size.");
+    }
 }

 BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
@@ -54,7 +76,8 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
 void BatchNormOperation::validate_on_program_cache_miss(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
+
+    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;

     TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be must be tilized");
     TT_FATAL(
@@ -87,6 +110,20 @@ void BatchNormOperation::validate_on_program_cache_miss(
             "bias tensor must be interleaved");
     }

+    if (running_mean.has_value()) {
+        TT_FATAL(running_mean.value().get_layout() == Layout::TILE, "running_mean tensor must be tilized");
+        TT_FATAL(
+            running_mean.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
+            "running_mean tensor must be interleaved");
+    }
+
+    if (running_var.has_value()) {
+        TT_FATAL(running_var.value().get_layout() == Layout::TILE, "running_var tensor must be tilized");
+        TT_FATAL(
+            running_var.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
+            "running_var tensor must be interleaved");
+    }
+
     validate_tensors(operation_attributes, tensor_args);
 };

@@ -127,10 +164,20 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tensor_args_t> BatchNormOperation::invoke(
     std::optional<Tensor> weight,
     std::optional<Tensor> bias,
+    std::optional<Tensor> running_mean,
+    std::optional<Tensor> running_var,
     std::optional<Tensor> output,
     const std::optional<MemoryConfig>& memory_config) {
     operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
-    tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
+    tensor_args_t tensor_args{
+        input,
+        batch_mean,
+        batch_var,
+        std::move(weight),
+        std::move(bias),
+        std::move(running_mean),
+        std::move(running_var),
+        std::move(output)};
     return {operation_attributes, tensor_args};
 }
 }  // namespace ttnn::operations::normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
index 985634f6dfdb..d9d848c7564a 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
@@ -27,6 +27,8 @@ struct BatchNormOperation {
         const Tensor& batch_var;
         std::optional<Tensor> weight;
         std::optional<Tensor> bias;
+        std::optional<Tensor> running_mean;
+        std::optional<Tensor> running_var;
         std::optional<Tensor> output;
     };

@@ -72,6 +74,8 @@ struct BatchNormOperation {
         const bool training,
         std::optional<Tensor> weight,
         std::optional<Tensor> bias,
+        std::optional<Tensor> running_mean,
+        std::optional<Tensor> running_var,
         std::optional<Tensor> output,
         const std::optional<MemoryConfig>& memory_config);
 };
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
index d9f57686260f..8fac53823288 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
@@ -29,12 +29,15 @@ void set_or_update_runtime_arguments(
     const BatchNormOperation::tensor_args_t& tensor_args,
     BatchNormOperation::tensor_return_value_t& c,
     F handle_args) {
-    const auto& [a, b, d, e, f, _] = tensor_args;
+    const auto& [a, b, d, e, f, g, h, _] = tensor_args;
     const auto eps = operation_attributes.eps;
     const auto momentum = operation_attributes.momentum;

     const bool weight_has_value = e.has_value();
     const bool bias_has_value = f.has_value();
+    const bool running_mean_has_value = g.has_value();
+    const bool running_var_has_value = h.has_value();
+    const bool is_training_mode = operation_attributes.training;

     const auto ashape = a.padded_shape();
     const auto bshape = b.padded_shape();
@@ -65,7 +68,7 @@ void set_or_update_runtime_arguments(
             num_tiles_per_core = num_tiles_per_core_group_2;
         } else {
             handle_args(program, reader_kernel_id, core, std::array{0});
-            handle_args(program, writer_kernel_id, core, std::array<uint32_t, 12>{0});
+            handle_args(program, writer_kernel_id, core, std::array<uint32_t, 14>{0});
             handle_args(program, compute_kernel_id, core, std::array{0});
             continue;
         }
@@ -93,11 +96,15 @@ void set_or_update_runtime_arguments(

         const auto weight_addr = weight_has_value ? e->buffer()->address() : 0;
         const auto bias_addr = bias_has_value ? f->buffer()->address() : 0;
+        const auto running_mean_addr = is_training_mode and running_mean_has_value ? g->buffer()->address() : 0;
+        const auto running_var_addr = is_training_mode and running_var_has_value ? h->buffer()->address() : 0;
         std::array writer_runtime_args = {
             b.buffer()->address(),  // batch mean
             d.buffer()->address(),  // batch var
             weight_addr,            // weight
             bias_addr,              // bias
+            running_mean_addr,      // old running mean
+            running_var_addr,       // old running var
             c.buffer()->address(),  // output
             start_tile_id,
             num_tiles_per_core,
@@ -131,7 +138,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     using namespace tt;
     using namespace tt::tt_metal;

-    const auto& [a, b, d, e, f, _] = tensor_args;
+    const auto& [a, b, d, e, f, g, h, _] = tensor_args;

     auto program = CreateProgram();

@@ -139,6 +146,9 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     const bool weight_has_value = e.has_value();
     const bool bias_has_value = f.has_value();
+    const bool running_mean_has_value = g.has_value();
+    const bool running_var_has_value = h.has_value();
+    const bool is_training_mode = operation_attributes.training;

     auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
     auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
@@ -146,6 +156,10 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     auto d_data_format = datatype_to_dataformat_converter(d.get_dtype());
     auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
     auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b;
+    auto g_data_format = is_training_mode and running_mean_has_value ? datatype_to_dataformat_converter(g->get_dtype())
+                                                                     : DataFormat::Float16_b;
+    auto h_data_format = is_training_mode and running_var_has_value ? datatype_to_dataformat_converter(h->get_dtype())
+                                                                    : DataFormat::Float16_b;

     uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
     uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
@@ -153,6 +167,8 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     uint32_t d_single_tile_size = tt_metal::detail::TileSize(d_data_format);
     uint32_t e_single_tile_size = tt_metal::detail::TileSize(e_data_format);
     uint32_t f_single_tile_size = tt_metal::detail::TileSize(f_data_format);
+    uint32_t g_single_tile_size = tt_metal::detail::TileSize(g_data_format);
+    uint32_t h_single_tile_size = tt_metal::detail::TileSize(h_data_format);

     uint32_t num_output_tiles = output.volume() / output.tensor_spec().tile().get_tile_hw();

@@ -199,6 +215,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         d_single_tile_size,
         b_num_tiles_per_cb,
         d_data_format);  // momentum
+    auto [g_cb, g_cb_handle] = create_cb(
+        tt::CBIndex::c_25,
+        program,
+        all_device_cores,
+        g_single_tile_size,
+        b_num_tiles_per_cb,
+        g_data_format);  // old running mean
+    auto [h_cb, h_cb_handle] = create_cb(
+        tt::CBIndex::c_26,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // old running var

     // Temporary buffers to store intermediate results
     auto [den_cb, den_cb_handle] = create_cb(
@@ -217,6 +247,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         a_data_format);  // to store input - batch_mean
     auto [temp_1_cb, temp_1_cb_handle] =
         create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+    auto [updated_m_cb, updated_m_cb_handle] = create_cb(
+        tt::CBIndex::c_27,
+        program,
+        all_device_cores,
+        g_single_tile_size,
+        b_num_tiles_per_cb,
+        g_data_format);  // updated running mean
+    auto [updated_v_cb, updated_v_cb_handle] = create_cb(
+        tt::CBIndex::c_28,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // updated running var

     auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
@@ -224,6 +268,10 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
     const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+    const auto g_is_dram =
+        is_training_mode and running_mean_has_value and g->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+    const auto h_is_dram =
+        is_training_mode and running_var_has_value and h->buffer()->buffer_type() == tt_metal::BufferType::DRAM;

     // READER KERNEL
     auto reader_kernel_id = tt_metal::CreateKernel(
@@ -237,15 +285,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         program,
         "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp",
         all_device_cores,
-        tt_metal::WriterDataMovementConfig(
-            {b_is_dram,
-             c_is_dram,
-             d_is_dram,
-             e_is_dram,
-             f_is_dram,
-             static_cast<uint32_t>(weight_has_value),
-             static_cast<uint32_t>(bias_has_value),
-             static_cast<uint32_t>(operation_attributes.training)}));
+        tt_metal::WriterDataMovementConfig({
+            b_is_dram,
+            c_is_dram,
+            d_is_dram,
+            e_is_dram,
+            f_is_dram,
+            static_cast<uint32_t>(weight_has_value),
+            static_cast<uint32_t>(bias_has_value),
+            static_cast<uint32_t>(operation_attributes.training),
+            g_is_dram,
+            h_is_dram,
+            static_cast<uint32_t>(running_mean_has_value),
+            static_cast<uint32_t>(running_var_has_value),
+        }));

     // COMPUTE KERNEL
     bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
@@ -253,7 +306,9 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     std::vector<uint32_t> compute_kernel_args = {
         static_cast<uint32_t>(weight_has_value),
         static_cast<uint32_t>(bias_has_value),
-        static_cast<uint32_t>(operation_attributes.training)};
+        static_cast<uint32_t>(operation_attributes.training),
+        static_cast<uint32_t>(running_mean_has_value),
+        static_cast<uint32_t>(running_var_has_value)};
     auto compute_kernel_id = tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp",
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
index 015f8b806d9a..57a93092bb70 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
@@ -40,6 +40,8 @@ void MAIN {
     constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
     constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
     constexpr uint32_t is_training_mode = get_compile_time_arg_val(2) == 1;
+    constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(3) == 1;
+    constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(4) == 1;

     if (num_tiles == 0) {
         return;
@@ -56,6 +58,11 @@ void MAIN {
     constexpr auto cb_weight = tt::CBIndex::c_16;  // weight tensor
     constexpr auto cb_tmp_1 = tt::CBIndex::c_17;   // (input - batch_mean)/(sqrt(batch_var + eps))
     constexpr auto cb_bias = tt::CBIndex::c_18;    // bias tensor
+    constexpr auto cb_old_running_mean = tt::CBIndex::c_25;      // old running mean tensor
+    constexpr auto cb_old_running_var = tt::CBIndex::c_26;       // old running var tensor
+    constexpr auto cb_updated_running_mean = tt::CBIndex::c_27;  // updated running mean tensor
+    constexpr auto cb_updated_running_var = tt::CBIndex::c_28;   // updated running var tensor
+    constexpr auto cb_momentum = tt::CBIndex::c_24;              // momentum

     auto cb_bcast = cb_batch_mean;
     auto cb_other = cb_input;
@@ -118,7 +125,12 @@ void MAIN {
             cb_push_back(cb_affine_or_out, onetile);

             if constexpr (is_training_mode) {
-                // update running stats here
+                // update running stats
+                if constexpr (old_running_mean_has_value) {
+                }
+
+                if constexpr (old_running_var_has_value) {
+                }
             }

             if constexpr (weight_has_value) {  // result = result * weight
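The training-mode branches above are stubs in this change: the circular buffers (c_24 momentum, c_25/c_26 old stats, c_27/c_28 updated stats) are plumbed through so the on-device update can land later. For reference, a sketch in Python of the PyTorch-style update rule such a kernel would implement; note that PyTorch applies an unbiased correction to the batch variance when updating running_var, and whether this change will do the same is not shown here:

```python
def update_running_stat(old_stat, batch_stat, momentum):
    # new_running = (1 - momentum) * old_running + momentum * batch_stat
    return (1.0 - momentum) * old_stat + momentum * batch_stat
```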
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
index d0eaa41d6a41..0b6abebcf5ef 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
@@ -12,14 +12,16 @@ void kernel_main() {
     uint32_t batch_var_addr = get_arg_val<uint32_t>(1);  // batch_var
     uint32_t weight_addr = get_arg_val<uint32_t>(2);     // weight
     uint32_t bias_addr = get_arg_val<uint32_t>(3);       // bias
-    uint32_t dst_addr = get_arg_val<uint32_t>(4);        // output
-    uint32_t start_tile_id = get_arg_val<uint32_t>(5);
-    uint32_t num_tiles = get_arg_val<uint32_t>(6);
-    uint32_t HtWt = get_arg_val<uint32_t>(7);
-    uint32_t n_stride = get_arg_val<uint32_t>(8);
-    uint32_t c_stride = get_arg_val<uint32_t>(9);
-    uint32_t N = get_arg_val<uint32_t>(10);
-    uint32_t C = get_arg_val<uint32_t>(11);
+    uint32_t old_running_mean_addr = get_arg_val<uint32_t>(4);  // old running_mean
+    uint32_t old_running_var_addr = get_arg_val<uint32_t>(5);   // old running_var
+    uint32_t dst_addr = get_arg_val<uint32_t>(6);               // output
+    uint32_t start_tile_id = get_arg_val<uint32_t>(7);
+    uint32_t num_tiles = get_arg_val<uint32_t>(8);
+    uint32_t HtWt = get_arg_val<uint32_t>(9);
+    uint32_t n_stride = get_arg_val<uint32_t>(10);
+    uint32_t c_stride = get_arg_val<uint32_t>(11);
+    uint32_t N = get_arg_val<uint32_t>(12);
+    uint32_t C = get_arg_val<uint32_t>(13);

     constexpr uint32_t onetile = 1;

@@ -72,6 +74,33 @@ void kernel_main() {
     constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1;
     constexpr bool is_training_mode = get_compile_time_arg_val(7) == 1;

+    // old running mean
+    constexpr auto cb_id_old_running_mean = tt::CBIndex::c_25;
+    constexpr bool old_running_mean_is_dram = get_compile_time_arg_val(8) == 1;
+    const uint32_t old_running_mean_tile_bytes = get_tile_size(cb_id_old_running_mean);
+    const DataFormat old_running_mean_data_format = get_dataformat(cb_id_old_running_mean);
+
+    const InterleavedAddrGenFast<old_running_mean_is_dram> old_running_mean = {
+        .bank_base_address = old_running_mean_addr,
+        .page_size = old_running_mean_tile_bytes,
+        .data_format = old_running_mean_data_format};
+
+    // old running var
+    constexpr auto cb_id_old_running_var = tt::CBIndex::c_26;
+    constexpr bool old_running_var_is_dram = get_compile_time_arg_val(9) == 1;
+    const uint32_t old_running_var_tile_bytes = get_tile_size(cb_id_old_running_var);
+    const DataFormat old_running_var_data_format = get_dataformat(cb_id_old_running_var);
+
+    const InterleavedAddrGenFast<old_running_var_is_dram> old_running_var = {
+        .bank_base_address = old_running_var_addr,
+        .page_size = old_running_var_tile_bytes,
+        .data_format = old_running_var_data_format};
+
+    constexpr bool old_running_mean_has_value = get_compile_time_arg_val(10) == 1;
+    constexpr bool old_running_var_has_value = get_compile_time_arg_val(11) == 1;
+    constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
+    constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;
+
     uint32_t tiles_per_batch = HtWt * C;
     uint32_t start_n = start_tile_id / tiles_per_batch;
     uint32_t start_remaining = start_tile_id % tiles_per_batch;
@@ -121,6 +150,39 @@ void kernel_main() {

         // to read running stats value for updation
         if constexpr (is_training_mode) {
+            if constexpr (old_running_mean_has_value) {
+                // read the old running mean tile into its CB
+                cb_reserve_back(cb_id_old_running_mean, onetile);
+                uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean);
+                noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr);
+                noc_async_read_barrier();
+                fill_tile_with_first_element_bfloat16(cb_id_old_running_mean);
+                cb_push_back(cb_id_old_running_mean, onetile);
+
+                // write the updated running mean back to the same tile in place
+                cb_wait_front(cb_id_updated_running_mean, onetile);
+                uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean);
+                noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr);
+                noc_async_write_barrier();
+                cb_pop_front(cb_id_updated_running_mean, onetile);
+            }
+
+            if constexpr (old_running_var_has_value) {
+                // read the old running var tile into its CB
+                cb_reserve_back(cb_id_old_running_var, onetile);
+                uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var);
+                noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr);
+                noc_async_read_barrier();
+                fill_tile_with_first_element_bfloat16(cb_id_old_running_var);
+                cb_push_back(cb_id_old_running_var, onetile);
+
+                // write the updated running var back to the same tile in place
+                cb_wait_front(cb_id_updated_running_var, onetile);
+                uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var);
+                noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr);
+                noc_async_write_barrier();
+                cb_pop_front(cb_id_updated_running_var, onetile);
+            }
         }

         for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
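Once the compute-kernel update is filled in, the commented-out checks in the test above could be enabled along these lines (a sketch reusing the test's names; `torch.nn.functional.batch_norm` updates `mean_data` / `var_data` in place when `training=True`, so they serve as the reference):

```python
if training:
    if check_mean:
        tt_updated_mean = ttnn.to_torch(mean_tensor)
        comp_pass = comp_pass and compare_results_batch_norm([tt_updated_mean], [mean_data])
    if check_var:
        tt_updated_var = ttnn.to_torch(var_tensor)
        comp_pass = comp_pass and compare_results_batch_norm([tt_updated_var], [var_data])
```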