diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py index efe6e390e9a6..9929f19f9819 100644 --- a/tests/ttnn/unit_tests/operations/test_batch_norm.py +++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py @@ -15,12 +15,12 @@ @pytest.mark.parametrize( "input_shapes", [ - # *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), - # torch.Size([4, 4, 32, 32]), - # *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), - # torch.Size([4, 4, 23, 23]), - # *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])), - # torch.Size([3, 1, 64, 120]), + *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), + torch.Size([4, 4, 32, 32]), + *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3]) if not (n == 3 and c == 3)), + torch.Size([4, 4, 23, 23]), + *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])), + torch.Size([3, 1, 64, 120]), torch.Size([3, 2, 64, 120]), ], ) @@ -28,31 +28,29 @@ "training, check_mean, check_var", [ (True, True, True), - # # (True, True, False), - # # (True, False, True), - # (True, False, False), - # (False, False, False), # xfail case - # (False, True, False), # xfail case - # (False, False, True), # xfail case - # (False, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, False, False), # xfail case + (False, True, False), # xfail case + (False, False, True), # xfail case + (False, True, True), ], ) @pytest.mark.parametrize("weight", [True, False]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05]) -@pytest.mark.parametrize("momentum", [0.1, 0.0]) +@pytest.mark.parametrize("momentum", [0.1, 0.0, 1.0, 2.3]) def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device): in_data, input_tensor = 
data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True) mean_data, mean_tensor = ( data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None) ) var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None) - print("mean_tensor", mean_tensor) - print("var_tensor", var_tensor) weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None) bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None) - if (not check_mean) or (not check_var): + if (not training) and ((not check_mean) or (not check_var)): pytest.xfail("running_mean and running_var must be defined in evaluation mode") tt_output_tensor_on_device = ttnn.batch_norm( @@ -66,12 +64,14 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, bias=bias_tensor, ) tt_output = ttnn.to_torch(tt_output_tensor_on_device) + tt_updated_mean = None + tt_updated_var = None + if training: + if check_mean: + tt_updated_mean = ttnn.to_torch(mean_tensor) + if check_var: + tt_updated_var = ttnn.to_torch(var_tensor) - tt_updated_mean = ttnn.to_torch(mean_tensor) - tt_updated_var = ttnn.to_torch(var_tensor) - # ttnn.set_printoptions(profile="full") - # print("TT result : ", tt_output, tt_output.shape) - # torch.set_printoptions(precision=5, sci_mode=False) torch_result = torch.nn.functional.batch_norm( input=in_data, running_mean=mean_data, @@ -82,21 +82,28 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps=eps, momentum=momentum, ) - batch_mean = in_data.mean(dim=(0, 2, 3)) - batch_var = in_data.var(dim=(0, 2, 3), unbiased=False) - print("Batch mean:", batch_mean) - print("Batch variance:", batch_var) - print("mean_data", mean_data) - print("tt_updated_mean", tt_updated_mean) - print("var_data", var_data) - print("tt_updated_var", tt_updated_var) - # 
print("Torch result : ",torch_result) comp_pass = compare_results_batch_norm([tt_output], [torch_result])  # Check BN Result -    # if training : -    #     channels = input_shapes[1] -    #     comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data.view(1, channels, 1, 1)]) # Check Updated running mean -    #     comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data.view(1, channels, 1, 1)]) # Check Updated running var -    #     comp_pass = comp_pass and comp_pass_1 and comp_pass_2 +    if training: +        channels = input_shapes[1] +        if check_mean: +            comp_pass_1 = compare_results_batch_norm( +                [tt_updated_mean], [mean_data.view(1, channels, 1, 1)] +            )  # Check Updated running mean +        else: +            if tt_updated_mean is None: +                comp_pass_1 = True +            else: +                comp_pass_1 = False +        if check_var: +            comp_pass_2 = compare_results_batch_norm( +                [tt_updated_var], [var_data.view(1, channels, 1, 1)] +            )  # Check Updated running var +        else: +            if tt_updated_var is None: +                comp_pass_2 = True +            else: +                comp_pass_2 = False +        comp_pass = comp_pass and comp_pass_1 and comp_pass_2 assert comp_pass diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index e1ee461ddedd..c502b7aa8b3d 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -262,7 +262,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch b_num_tiles_per_cb, h_data_format);  // updated running var -    // Intermediate buffer +    // Intermediate buffers required for updating running stats auto [one_cb, one_cb_handle] = create_cb( tt::CBIndex::c_19, program, @@ -271,29 +271,14 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch b_num_tiles_per_cb, h_data_format);  // to store 1 -    auto [tmp1_cb,
tmp1_cb_handle] = create_cb( -        tt::CBIndex::c_29, -        program, -        all_device_cores, -        h_single_tile_size, -        b_num_tiles_per_cb, -        h_data_format);  // to store tmp +    auto [tmp1_cb, tmp1_cb_handle] = +        create_cb(tt::CBIndex::c_29, program, all_device_cores, h_single_tile_size, b_num_tiles_per_cb, h_data_format); -    auto [tmp2_cb, tmp2_cb_handle] = create_cb( -        tt::CBIndex::c_30, -        program, -        all_device_cores, -        h_single_tile_size, -        b_num_tiles_per_cb, -        h_data_format);  // to store tmp +    auto [tmp2_cb, tmp2_cb_handle] = +        create_cb(tt::CBIndex::c_30, program, all_device_cores, h_single_tile_size, b_num_tiles_per_cb, h_data_format); -    auto [tmp3_cb, tmp3_cb_handle] = create_cb( -        tt::CBIndex::c_31, -        program, -        all_device_cores, -        h_single_tile_size, -        b_num_tiles_per_cb, -        h_data_format);  // to store tmp +    auto [tmp3_cb, tmp3_cb_handle] = +        create_cb(tt::CBIndex::c_31, program, all_device_cores, h_single_tile_size, b_num_tiles_per_cb, h_data_format); auto a_is_dram = static_cast(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM); auto b_is_dram = static_cast(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp index 3eed509754b9..bc397440da87 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp @@ -128,8 +128,9 @@ void MAIN { cb_pop_front(cb_den, 1); cb_push_back(cb_affine_or_out, onetile); +        // Update running stats if constexpr (is_training_mode) { -            // updated running stats +            // updated_running_stat = (1 − momentum) × running_stat + momentum × batch_stat if constexpr (old_running_mean_has_value) { sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0);  // 1 - momentum
mul_tiles_to_cb(cb_momentum, cb_batch_mean, cb_tmp2, 0, tile_id, 0, 0); // momentum * running stats @@ -141,7 +142,6 @@ void MAIN { sub_tiles_to_cb(cb_one, cb_momentum, cb_tmp1, tile_id, 0, 0, 0); // 1 - momentum mul_tiles_to_cb(cb_momentum, cb_batch_var, cb_tmp2, 0, tile_id, 0, 0); // momentum * batch stat mul_tiles_to_cb(cb_tmp1, cb_old_running_var, cb_tmp3, 0, tile_id, 0, 1); // cb_tmp1 * running stats - DPRINT << TSLICE(tt::CBIndex::c_26, 0, SliceRange::hw0_32_16()) << ENDL(); add_tiles_to_cb(cb_tmp2, cb_tmp3, cb_updated_running_var, 0, 0, 1, 1); } } diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp index fb59ad447517..7dddf29c23a6 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp @@ -38,18 +38,19 @@ void kernel_main() { constexpr auto cb_id_eps = tt::CBIndex::c_4; constexpr auto cb_id_one = tt::CBIndex::c_19; + constexpr auto cb_id_momentum = tt::CBIndex::c_24; cb_reserve_back(cb_id_eps, onetile); fill_with_val_bfloat16(cb_id_eps, eps); cb_push_back(cb_id_eps, onetile); - constexpr auto cb_id_momentum = tt::CBIndex::c_24; union { float f; uint32_t u; } scalar; scalar.f = 1.0f; fill_cb_with_value(cb_id_one, scalar.u); + cb_reserve_back(cb_id_momentum, onetile); fill_with_val_bfloat16(cb_id_momentum, momentum); cb_push_back(cb_id_momentum, onetile); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp index 0b6abebcf5ef..392b65d314b1 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp +++ 
b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp @@ -148,7 +148,7 @@ void kernel_main() { cb_push_back(cb_id_bias, onetile); } -    // to read running stats value for updation +    // Update running stats if constexpr (is_training_mode) { if constexpr (old_running_mean_has_value) { // read data