diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py
index b655cf1cc1b..66d5d432d01 100644
--- a/tests/ttnn/unit_tests/operations/test_batch_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -15,10 +15,8 @@
 @pytest.mark.parametrize(
     "input_shapes",
     [
-        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])),
-        torch.Size([4, 4, 32, 32]),
-        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3]) if not (n == 3 and c == 3)),
-        torch.Size([4, 4, 23, 23]),
+        *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
+        *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3, 4])),
         *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])),
         torch.Size([3, 1, 64, 120]),
         torch.Size([3, 2, 64, 120]),
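Note: the reparametrized cases above now sweep the full product of batch and channel counts for the 32x32 and 23x23 inputs, rather than listing [4, 4, ...] separately and skipping (n, c) = (3, 3). The quantity these tests validate is the standard batch-norm inference transform that the kernel below computes per channel. A minimal PyTorch sketch of that reference, assuming NCHW layout and per-channel stats (the helper name and signature are illustrative, not the test's actual golden function):

```python
import torch

def batch_norm_reference(x, batch_mean, batch_var, eps=1e-5, weight=None, bias=None):
    # Per-channel stats broadcast over an NCHW tensor: [C] -> [1, C, 1, 1].
    num = x - batch_mean.view(1, -1, 1, 1)                # input - batch_mean
    den = torch.rsqrt(batch_var.view(1, -1, 1, 1) + eps)  # 1/(sqrt(batch_var + eps))
    result = num * den
    if weight is not None:  # result = result * weight
        result = result * weight.view(1, -1, 1, 1)
    if bias is not None:  # result = result + bias
        result = result + bias.view(1, -1, 1, 1)
    return result
```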
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
index f953c8b9b13..5a7525b7a4e 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
@@ -9,99 +9,72 @@
 
 namespace NAMESPACE {
-ALWI void subtract_bcast_tiles(
-    uint32_t cb_bcast, uint32_t cb_other, uint32_t cb_out, uint32_t freq, uint32_t tile_start) {
+ALWI void batchnorm_bcast_tiles(
+    uint32_t cb_bcast,
+    uint32_t cb_other,
+    uint32_t freq,
+    uint32_t tile_start,
+    uint32_t cb_batch_var,
+    uint32_t cb_eps,
+    uint32_t cb_den,
+    uint32_t cb_num,
+    uint32_t cb_weight,
+    uint32_t cb_bias,
+    uint32_t cb_tmp_1,
+    uint32_t cb_output_0,
+    uint32_t weight_has,
+    uint32_t bias_has) {
     constexpr uint32_t onetile = 1;
+    constexpr int dst0 = 0;
+    uint32_t weight_has_value = weight_has;
+    uint32_t bias_has_value = bias_has;
+    auto cb_affine_or_out = (weight_has_value || bias_has_value) ? cb_tmp_1 : cb_output_0;
+    auto cb_scaled_output = (bias_has_value) ? cb_tmp_1 : cb_output_0;
     cb_wait_front(cb_bcast, onetile);
     for (uint32_t j = tile_start; j < freq; ++j) {
         cb_wait_front(cb_other, onetile);
-        cb_reserve_back(cb_out, onetile);
+        cb_reserve_back(cb_num, onetile);
 
         tile_regs_acquire();
         sub_tiles(cb_other, cb_bcast, 0, 0, 0);
         tile_regs_commit();
 
         tile_regs_wait();
-        pack_tile(0, cb_out);
+        pack_tile(0, cb_num);
         tile_regs_release();
 
-        cb_push_back(cb_out, onetile);
+        cb_push_back(cb_num, onetile);
         cb_pop_front(cb_other, onetile);
     }
     cb_pop_front(cb_bcast, onetile);
-}
-
-void MAIN {
-    uint32_t num_tiles = get_arg_val<uint32_t>(0);
-    uint32_t tile_freq = get_arg_val<uint32_t>(1);
-    uint32_t tile_start = get_arg_val<uint32_t>(2);
-    constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
-    constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
-
-    if (num_tiles == 0) {
-        return;
-    }
-
-    constexpr auto cb_input = tt::CBIndex::c_0;       // input
-    constexpr auto cb_batch_mean = tt::CBIndex::c_1;  // batch_mean
-    constexpr auto cb_output_0 =
-        tt::CBIndex::c_2;  // output -- > [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
-    constexpr auto cb_batch_var = tt::CBIndex::c_3;  // batch_var
-    constexpr auto cb_eps = tt::CBIndex::c_4;        // eps
-    constexpr auto cb_den = tt::CBIndex::c_5;        // 1/(sqrt(batch_var + eps))
-    constexpr auto cb_num = tt::CBIndex::c_6;        // input - batch_mean
-    constexpr auto cb_weight = tt::CBIndex::c_16;    // weight tensor
-    constexpr auto cb_tmp_1 = tt::CBIndex::c_17;     // (input - batch_mean)/(sqrt(batch_var + eps))
-    constexpr auto cb_bias = tt::CBIndex::c_18;      // bias tensor
-    auto cb_bcast = cb_batch_mean;
-    auto cb_other = cb_input;
+    // 1/(sqrt(batch_var + eps))
+    cb_reserve_back(cb_den, onetile);
+    cb_wait_front(cb_batch_var, 1);
+    cb_wait_front(cb_eps, 1);
 
-    binary_op_init_common(cb_bcast, cb_other, cb_output_0);
+    tile_regs_acquire();
+    add_tiles_init_with_dt(cb_batch_var, cb_eps);
+    add_tiles(cb_batch_var, cb_eps, 0, 0, dst0);
+    rsqrt_tile_init();
+    rsqrt_tile(dst0);
+    tile_regs_commit();
 
-    // input - batch_mean
-    sub_tiles_init();
-    uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
-    uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
-    for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
-        subtract_bcast_tiles(cb_bcast, cb_other, cb_num, tile_freq, tile_start);
-    }
-    if (remaining_iterations > 0) {
-        subtract_bcast_tiles(cb_bcast, cb_other, cb_num, remaining_iterations, tile_start);
-    }
+    tile_regs_wait();
+    pack_tile_with_dt(dst0, cb_den);
+    tile_regs_release();
 
-    constexpr uint32_t onetile = 1;
-    constexpr int dst0 = 0;
-
-    constexpr auto cb_affine_or_out = (weight_has_value || bias_has_value) ? cb_tmp_1 : cb_output_0;
-    constexpr auto cb_scaled_output = (bias_has_value) ? cb_tmp_1 : cb_output_0;
-    for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
-        // 1/(sqrt(batch_var + eps))
-        cb_reserve_back(cb_den, onetile);
-        cb_wait_front(cb_batch_var, 1);
-        cb_wait_front(cb_eps, 1);
-
-        tile_regs_acquire();
-        add_tiles_init_with_dt(cb_batch_var, cb_eps);
-        add_tiles(cb_batch_var, cb_eps, 0, 0, dst0);
-        rsqrt_tile_init();
-        rsqrt_tile(dst0);
-        tile_regs_commit();
+    cb_pop_front(cb_batch_var, 1);
+    cb_pop_front(cb_eps, 1);
+    cb_push_back(cb_den, onetile);
 
-        tile_regs_wait();
-        pack_tile_with_dt(dst0, cb_den);
-        tile_regs_release();
-
-        cb_pop_front(cb_batch_var, 1);
-        cb_pop_front(cb_eps, 1);
-        cb_push_back(cb_den, onetile);
-
-        // (input - batch_mean)/(sqrt(batch_var + eps)) = result
-        cb_reserve_back(cb_affine_or_out, onetile);
+    // (input - batch_mean)/(sqrt(batch_var + eps)) = result
+    cb_wait_front(cb_den, 1);
+    for (uint32_t j = tile_start; j < freq; ++j) {
         cb_wait_front(cb_num, 1);
-        cb_wait_front(cb_den, 1);
+        cb_reserve_back(cb_affine_or_out, onetile);
 
         tile_regs_acquire();
         mul_tiles_init_with_dt(cb_num, cb_den);
@@ -113,13 +86,15 @@ void MAIN {
         tile_regs_release();
 
         cb_pop_front(cb_num, 1);
-        cb_pop_front(cb_den, 1);
         cb_push_back(cb_affine_or_out, onetile);
+    }
+    cb_pop_front(cb_den, 1);
 
-        if constexpr (weight_has_value) {  // result = result * weight
+    if (weight_has_value) {  // result = result * weight
+        cb_wait_front(cb_weight, 1);
+        for (uint32_t j = tile_start; j < freq; ++j) {
             cb_reserve_back(cb_scaled_output, onetile);
             cb_wait_front(cb_affine_or_out, 1);
-            cb_wait_front(cb_weight, 1);
 
             tile_regs_acquire();
             mul_tiles_init_with_dt(cb_affine_or_out, cb_weight);
@@ -131,13 +106,16 @@ void MAIN {
             tile_regs_release();
 
             cb_pop_front(cb_affine_or_out, 1);
-            cb_pop_front(cb_weight, 1);
+
             cb_push_back(cb_scaled_output, onetile);
         }
-        if constexpr (bias_has_value) {  // result = result + bias
+        cb_pop_front(cb_weight, 1);
+    }
+    if (bias_has_value) {  // result = result + bias
+        cb_wait_front(cb_bias, 1);
+        for (uint32_t j = tile_start; j < freq; ++j) {
             cb_reserve_back(cb_output_0, 1);
             cb_wait_front(cb_tmp_1, 1);
-            cb_wait_front(cb_bias, 1);
 
             tile_regs_acquire();
             add_tiles_init_with_dt(cb_tmp_1, cb_bias);
@@ -149,9 +127,79 @@ void MAIN {
             tile_regs_release();
 
             cb_pop_front(cb_tmp_1, 1);
-            cb_pop_front(cb_bias, 1);
             cb_push_back(cb_output_0, 1);
         }
+        cb_pop_front(cb_bias, 1);
     }
 }
+
+void MAIN {
+    uint32_t num_tiles = get_arg_val<uint32_t>(0);
+    uint32_t tile_freq = get_arg_val<uint32_t>(1);
+    uint32_t tile_start = get_arg_val<uint32_t>(2);
+    constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
+    constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
+
+    if (num_tiles == 0) {
+        return;
+    }
+
+    constexpr auto cb_input = tt::CBIndex::c_0;       // input
+    constexpr auto cb_batch_mean = tt::CBIndex::c_1;  // batch_mean
+    constexpr auto cb_output_0 =
+        tt::CBIndex::c_2;  // output -- > [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
+    constexpr auto cb_batch_var = tt::CBIndex::c_3;  // batch_var
+    constexpr auto cb_eps = tt::CBIndex::c_4;        // eps
+    constexpr auto cb_den = tt::CBIndex::c_5;        // 1/(sqrt(batch_var + eps))
+    constexpr auto cb_num = tt::CBIndex::c_6;        // input - batch_mean
+    constexpr auto cb_weight = tt::CBIndex::c_16;    // weight tensor
+    constexpr auto cb_tmp_1 = tt::CBIndex::c_17;     // (input - batch_mean)/(sqrt(batch_var + eps))
+    constexpr auto cb_bias = tt::CBIndex::c_18;      // bias tensor
+
+    auto cb_bcast = cb_batch_mean;
+    auto cb_other = cb_input;
+
+    binary_op_init_common(cb_bcast, cb_other, cb_output_0);
+
+    sub_tiles_init();
+    uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
+    uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
+    for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
+        batchnorm_bcast_tiles(
+            cb_bcast,
+            cb_other,
+            tile_freq,
+            tile_start,
+            cb_batch_var,
+            cb_eps,
+            cb_den,
+            cb_num,
+            cb_weight,
+            cb_bias,
+            cb_tmp_1,
+            cb_output_0,
+            weight_has_value,
+            bias_has_value);
+    }
+    if (remaining_iterations > 0) {
+        batchnorm_bcast_tiles(
+            cb_bcast,
+            cb_other,
+            remaining_iterations,
+            tile_start,
+            cb_batch_var,
+            cb_eps,
+            cb_den,
+            cb_num,
+            cb_weight,
+            cb_bias,
+            cb_tmp_1,
+            cb_output_0,
+            weight_has_value,
+            bias_has_value);
+    }
+
+    constexpr uint32_t onetile = 1;
+    constexpr int dst0 = 0;
+}
 }  // namespace NAMESPACE
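Note: the work-splitting in the new MAIN is the least obvious part of the refactor. Each batchnorm_bcast_tiles call holds one broadcast tile (batch_mean) for up to freq input tiles, and only the very first call on a core may start mid-stride at a nonzero tile_start. A small Python model of that decomposition (an illustration of the index math only; names mirror the kernel's runtime args):

```python
def tiles_processed(num_tiles: int, tile_freq: int, tile_start: int) -> int:
    # Mirrors MAIN: split the work into complete strides plus a remainder.
    complete_iterations = (num_tiles + tile_start) // tile_freq
    remaining_iterations = (num_tiles + tile_start) % tile_freq
    processed = 0
    for _ in range(complete_iterations):
        processed += tile_freq - tile_start  # one full-stride batchnorm_bcast_tiles call
        tile_start = 0                       # only the first call may start mid-stride
    if remaining_iterations > 0:
        processed += remaining_iterations - tile_start
    return processed

# Every tile assigned to the core is covered exactly once, e.g.:
assert tiles_processed(num_tiles=7, tile_freq=4, tile_start=2) == 7
assert tiles_processed(num_tiles=3, tile_freq=8, tile_start=5) == 3
```

Because tile_start is zeroed after the first stride, the calls cover exactly num_tiles tiles in total, which is what lets the rsqrt(batch_var + eps) denominator be computed once per batchnorm_bcast_tiles call rather than once per tile as in the old per-tile_id loop.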