From ec8c29a3731e76698a04fc316d8b46db2518a573 Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Fri, 10 Jan 2025 10:58:47 +0000 Subject: [PATCH] #0: Write updated running stats --- .../batch_norm/batch_norm_pybind.cpp | 6 +++--- .../device/batch_norm_program_factory.cpp | 14 ++++++++++++++ .../kernels/compute/batch_norm_kernel.cpp | 5 ++++- .../kernels/dataflow/writer_batch_norm.cpp | 18 ++++++++++++++++++ 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp index c28165ef8c4e..92973c9da986 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp @@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) { module, ttnn::batch_norm, R"doc( - Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only. + Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved. Args: input_tensor (ttnn.Tensor): the input tensor of shape `[N, C, H, W]`. Keyword args: eps (float, optional): Epsilon value. Defaults to `1e-05`. momentum (float, optional): Momentum value. Defaults to `0.1`. - running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. - running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. + running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running mean value is stored in-place based on the inputs provided. Defaults to `None`. 
+ running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running variance value is stored in-place based on the inputs provided. Defaults to `None`. weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`. bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`. training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode). diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index 07cc6840a1c1..8fac53823288 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -247,6 +247,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch a_data_format); // to store input - batch_mean auto [temp_1_cb, temp_1_cb_handle] = create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); + auto [updated_m_cb, updated_m_cb_handle] = create_cb( + tt::CBIndex::c_27, + program, + all_device_cores, + g_single_tile_size, + b_num_tiles_per_cb, + g_data_format); // updated running mean + auto [updated_v_cb, updated_v_cb_handle] = create_cb( + tt::CBIndex::c_28, + program, + all_device_cores, + h_single_tile_size, + b_num_tiles_per_cb, + h_data_format); // updated running var auto a_is_dram = static_cast(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM); auto b_is_dram = static_cast(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp 
b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp index e653269561d0..57a93092bb70 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp @@ -60,6 +60,9 @@ void MAIN { constexpr auto cb_bias = tt::CBIndex::c_18; // bias tensor constexpr auto cb_old_running_mean = tt::CBIndex::c_25; // old running mean tensor constexpr auto cb_old_running_var = tt::CBIndex::c_26; // old running var tensor + constexpr auto cb_updated_running_mean = tt::CBIndex::c_27; // updated running mean tensor + constexpr auto cb_updated_running_var = tt::CBIndex::c_28; // updated running var tensor + constexpr auto cb_momentum = tt::CBIndex::c_24; // momentum auto cb_bcast = cb_batch_mean; auto cb_other = cb_input; @@ -122,7 +125,7 @@ void MAIN { cb_push_back(cb_affine_or_out, onetile); if constexpr (is_training_mode) { - // update running stats here + // updated running stats if constexpr (old_running_mean_has_value) { } diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp index a0f15101a886..0b6abebcf5ef 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp @@ -98,6 +98,8 @@ void kernel_main() { constexpr bool old_running_mean_has_value = get_compile_time_arg_val(10) == 1; constexpr bool old_running_var_has_value = get_compile_time_arg_val(11) == 1; + constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27; + constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28; uint32_t tiles_per_batch = HtWt * C; uint32_t start_n = start_tile_id / tiles_per_batch; @@ -149,21 +151,37 @@ void 
kernel_main() { // to read running stats value for updation if constexpr (is_training_mode) { if constexpr (old_running_mean_has_value) { + // read data cb_reserve_back(cb_id_old_running_mean, onetile); uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean); noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr); noc_async_read_barrier(); fill_tile_with_first_element_bfloat16(cb_id_old_running_mean); cb_push_back(cb_id_old_running_mean, onetile); + + // write data + cb_wait_front(cb_id_updated_running_mean, onetile); + uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean); + noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr); + noc_async_write_barrier(); + cb_pop_front(cb_id_updated_running_mean, onetile); } if constexpr (old_running_var_has_value) { + // read data cb_reserve_back(cb_id_old_running_var, onetile); uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var); noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr); noc_async_read_barrier(); fill_tile_with_first_element_bfloat16(cb_id_old_running_var); cb_push_back(cb_id_old_running_var, onetile); + + // write data + cb_wait_front(cb_id_updated_running_var, onetile); + uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var); + noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr); + noc_async_write_barrier(); + cb_pop_front(cb_id_updated_running_var, onetile); } }