From 2ea97cccb69cca26312f04b92f49da6df1c47f5a Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Sat, 11 Jan 2025 15:39:00 +0000
Subject: [PATCH] #16186: update running statistics in batch norm

---
 .../unit_tests/operations/test_batch_norm.py  | 46 +++++++++++--------
 .../device/batch_norm_program_factory.cpp     | 33 +++++++++++++
 .../kernels/compute/batch_norm_kernel.cpp     | 17 ++++++-
 .../kernels/dataflow/reader_batch_norm.cpp    |  8 +++-
 4 files changed, 83 insertions(+), 21 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py
index 459ba02ea17e..1aa782022404 100644
--- a/tests/ttnn/unit_tests/operations/test_batch_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -15,38 +15,40 @@
 @pytest.mark.parametrize(
     "input_shapes",
     [
-        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
-        torch.Size([4, 4, 32, 32]),
-        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
-        torch.Size([4, 4, 23, 23]),
-        *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
-        torch.Size([3, 1, 64, 120]),
+        # *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
+        # torch.Size([4, 4, 32, 32]),
+        # *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
+        # torch.Size([4, 4, 23, 23]),
+        # *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
+        # torch.Size([3, 1, 64, 120]),
         torch.Size([3, 2, 64, 120]),
     ],
 )
 @pytest.mark.parametrize(
     "training, check_mean, check_var",
     [
-        # (True, True, True),
-        # (True, True, False),
-        # (True, False, True),
-        (True, False, False),
-        (False, False, False),  # xfail case
-        (False, True, False),  # xfail case
-        (False, False, True),  # xfail case
-        (False, True, True),
+        (True, True, True),
+        # # (True, True, False),
+        # # (True, False, True),
+        # (True, False, False),
+        # (False, False, False),  # xfail case
+        # (False, True, False),  # xfail case
+        # (False, False, True),  # xfail case
+        # (False, True, True),
     ],
 )
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-@pytest.mark.parametrize("momentum", [0.1, 0.0, 2.3])
+@pytest.mark.parametrize("momentum", [0.1, 0.0])
 def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
     in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
     mean_data, mean_tensor = (
         data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
     )
     var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
+    print("mean_tensor", mean_tensor)
+    print("var_tensor", var_tensor)
     weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
     bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)
@@ -65,9 +67,8 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
     )
     tt_output = ttnn.to_torch(tt_output_tensor_on_device)
-    # tt_updated_mean = ttnn.to_torch(mean_tensor)
-    # tt_updated_var = ttnn.to_torch(var_tensor)
-
+    tt_updated_mean = ttnn.to_torch(mean_tensor)
+    tt_updated_var = ttnn.to_torch(var_tensor)
     # ttnn.set_printoptions(profile="full")
     # print("TT result : ", tt_output, tt_output.shape)
     # torch.set_printoptions(precision=5, sci_mode=False)
@@ -81,6 +82,14 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
         eps=eps,
         momentum=momentum,
     )
+    batch_mean = in_data.mean(dim=(0, 2, 3))
+    batch_var = in_data.var(dim=(0, 2, 3), unbiased=False)
+    print("Batch mean:", batch_mean)
+    print("Batch variance:", batch_var)
+    print("mean_data", mean_data)
+    print("tt_updated_mean", tt_updated_mean)
+    print("var_data", var_data)
+    print("tt_updated_var", tt_updated_var)
     # print("Torch result : ",torch_result)
     comp_pass = compare_results_batch_norm([tt_output], [torch_result])  # Check BN Result
     # if training :
@@ -88,6 +97,7 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
     #     comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data.view(1, channels, 1, 1)])  # Check Updated running mean
     #     comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data.view(1, channels, 1, 1)])  # Check Updated running var
     #     comp_pass = comp_pass and comp_pass_1 and comp_pass_2
+
     assert comp_pass
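Note on the reference math the new prints are inspecting: in training mode,
torch.nn.functional.batch_norm normalizes with the biased batch variance but
folds the *unbiased* batch variance into running_var. A minimal standalone
sketch of the expected update (tensor shapes, momentum value, and variable
names here are illustrative, not the test's actual fixtures):

    import torch

    x = torch.randn(3, 2, 64, 120)   # N, C, H, W
    running_mean = torch.zeros(2)    # one value per channel
    running_var = torch.ones(2)
    momentum = 0.1

    batch_mean = x.mean(dim=(0, 2, 3))
    biased_var = x.var(dim=(0, 2, 3), unbiased=False)   # used to normalize x
    unbiased_var = x.var(dim=(0, 2, 3), unbiased=True)  # folded into running_var

    expected_mean = (1 - momentum) * running_mean + momentum * batch_mean
    expected_var = (1 - momentum) * running_var + momentum * unbiased_var

If the device kernel folds the biased variance instead, the comparison against
PyTorch's updated running_var will drift, noticeably so for small N*H*W.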
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
index 1fcc09197362..991878fb7d86 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
@@ -262,6 +262,39 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         b_num_tiles_per_cb,
         h_data_format);  // updated running var

+    // Intermediate buffers for the running-stat update
+    auto [one_cb, one_cb_handle] = create_cb(
+        tt::CBIndex::c_19,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // to store 1
+
+    auto [tmp1_cb, tmp1_cb_handle] = create_cb(
+        tt::CBIndex::c_29,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // to store 1 - momentum
+
+    auto [tmp2_cb, tmp2_cb_handle] = create_cb(
+        tt::CBIndex::c_30,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // to store momentum * batch stat
+
+    auto [tmp3_cb, tmp3_cb_handle] = create_cb(
+        tt::CBIndex::c_31,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // to store (1 - momentum) * old running stat
+
     auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto c_is_dram = static_cast<uint32_t>(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
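For readers tracing the four new circular buffers: c_19 holds the constant 1,
and c_29/c_30/c_31 stage the three intermediates of
new = momentum * batch_stat + (1 - momentum) * old. The same staging on plain
floats (values illustrative; on device each operand is a 32x32 tile):

    momentum = 0.1
    batch_stat = 7.25    # per-channel batch mean or variance tile
    old_running = 5.0    # running stat read back from DRAM

    tmp1 = 1.0 - momentum          # c_29: 1 - momentum
    tmp2 = momentum * batch_stat   # c_30: momentum * batch stat
    tmp3 = tmp1 * old_running      # c_31: (1 - momentum) * old running stat
    updated = tmp2 + tmp3          # packed to c_27 (mean) or c_28 (var)
    assert abs(updated - 5.225) < 1e-9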
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
index e1a99a8ecaaa..dd7ca30397c7 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
@@ -30,7 +30,7 @@ ALWI void subtract_bcast_tiles(
         cb_push_back(cb_out, onetile);
         cb_pop_front(cb_other, onetile);
     }
-    cb_pop_front(cb_bcast, onetile);
+    // cb_pop_front(cb_bcast, onetile);
 }

 void MAIN {
@@ -63,6 +63,10 @@ void MAIN {
     constexpr auto cb_updated_running_mean = tt::CBIndex::c_27;  // updated running mean tensor
     constexpr auto cb_updated_running_var = tt::CBIndex::c_28;   // updated running var tensor
     constexpr auto cb_momentum = tt::CBIndex::c_24;              // momentum
+    constexpr auto cb_one = tt::CBIndex::c_19;                   // stores 1
+    constexpr auto cb_tmp1 = tt::CBIndex::c_29;                  // tmp 1
+    constexpr auto cb_tmp2 = tt::CBIndex::c_30;                  // tmp 2
+    constexpr auto cb_tmp3 = tt::CBIndex::c_31;                  // tmp 3

     auto cb_bcast = cb_batch_mean;
     auto cb_other = cb_input;
@@ -102,7 +106,7 @@ void MAIN {
             pack_tile_with_dt(dst0, cb_den);
             tile_regs_release();

-            cb_pop_front(cb_batch_var, 1);
+            // cb_pop_front(cb_batch_var, 1);
             cb_pop_front(cb_eps, 1);
             cb_push_back(cb_den, onetile);
@@ -127,9 +131,18 @@ void MAIN {
         if constexpr (is_training_mode) {
             // updated running stats
             if constexpr (old_running_mean_has_value) {
+                sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0);          // 1 - momentum
+                mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, tile_id, 0, 0);   // momentum * batch stat
+                mul_tiles_to_cb(cb_tmp1, cb_old_running_mean, cb_tmp3, 0, tile_id, 1, 0); // (1 - momentum) * old running stat
+                add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_mean, 0, 0, 1, 1);
             }

             if constexpr (old_running_var_has_value) {
+                sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0);         // 1 - momentum
+                mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, tile_id, 0, 0);   // momentum * batch stat
+                mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, tile_id, 0, 1); // (1 - momentum) * old running stat
+                DPRINT << TSLICE(tt::CBIndex::c_26, 0, SliceRange::hw0_32_16()) << ENDL();
+                add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1);
             }
         }

diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp
index 22a019963129..53e72e3fd74a 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp
@@ -37,13 +37,19 @@ void kernel_main() {
     uint32_t start_t = start_remaining % HtWt;

     constexpr auto cb_id_eps = tt::CBIndex::c_4;
+    constexpr auto cb_id_one = tt::CBIndex::c_19;

     cb_reserve_back(cb_id_eps, onetile);
     fill_with_val_bfloat16(cb_id_eps, eps);
     cb_push_back(cb_id_eps, onetile);

     constexpr auto cb_id_momentum = tt::CBIndex::c_24;
-
+    union {
+        float f;
+        uint32_t u;
+    } scalar;
+    scalar.f = 1.0f;
+    fill_cb_with_value(cb_id_one, scalar.u);
     cb_reserve_back(cb_id_momentum, onetile);
     fill_with_val_bfloat16(cb_id_momentum, momentum);
     cb_push_back(cb_id_momentum, onetile);
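The reader kernel's union is a type-pun: it reinterprets 1.0f as its raw
uint32_t bit pattern, since the patch passes scalar.u (packed bits, not a
float) to fill_cb_with_value. A quick host-side check of the same bit-cast,
sketched in Python for illustration:

    import struct

    # float -> u32 bit-cast, matching the kernel's union { float f; uint32_t u; }
    bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]
    assert bits == 0x3F800000  # IEEE-754 single-precision encoding of 1.0

Note the asymmetry the diff leaves in place: eps and momentum are filled as
bfloat16 via fill_with_val_bfloat16, while c_19 receives the full fp32 pattern
of 1.0 (its top 16 bits, 0x3F80, happen to encode exactly 1.0 in bfloat16 as
well, so the constant survives either interpretation).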