From 084b7b674ab4387511f734c4a743ebeb1007bf1c Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Fri, 24 Jan 2025 03:57:04 +0000 Subject: [PATCH] #12253: Update files --- .../device/batch_norm_program_factory.cpp | 69 ++++++++++--------- .../running_statistics_program_factory.cpp | 67 +++++++++--------- 2 files changed, 73 insertions(+), 63 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index dacc158dfa63..c640a45e00d8 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -29,18 +29,18 @@ void set_or_update_runtime_arguments( const BatchNormOperation::tensor_args_t& tensor_args, BatchNormOperation::tensor_return_value_t& c, F handle_args) { - const auto& [a, b, d, e, f, _] = tensor_args; + const auto& [input_tensor, batch_mean_tensor, batch_var_tensor, weight_tensor, bias_tensor, _] = tensor_args; const auto eps = operation_attributes.eps; - const bool weight_has_value = e.has_value(); - const bool bias_has_value = f.has_value(); + const bool weight_has_value = weight_tensor.has_value(); + const bool bias_has_value = bias_tensor.has_value(); - const auto ashape = a.padded_shape(); - const auto bshape = b.padded_shape(); + const auto ashape = input_tensor.padded_shape(); + const auto bshape = batch_mean_tensor.padded_shape(); const auto cshape = c.padded_shape(); - const auto [aN, aC, aHt, aWt] = extract_shape_dims(a); - const auto [bN, bC, bHt, bWt] = extract_shape_dims(b); + const auto [aN, aC, aHt, aWt] = extract_shape_dims(input_tensor); + const auto [bN, bC, bHt, bWt] = extract_shape_dims(batch_mean_tensor); const auto [cN, cC, cHt, cWt] = extract_shape_dims(c); uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw(); @@ -54,6 +54,9 @@ void set_or_update_runtime_arguments( tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major); auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); + constexpr size_t num_reader_args = 11; + constexpr size_t num_writer_args = 14; + constexpr size_t num_kernel_args = 3; for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) { const auto& core = cores[i]; @@ -63,9 +66,9 @@ void set_or_update_runtime_arguments( } else if (core_group_2.contains(core)) { num_tiles_per_core = num_tiles_per_core_group_2; } else { - handle_args(program, reader_kernel_id, core, std::array{0}); - handle_args(program, writer_kernel_id, core, std::array{0}); - handle_args(program, compute_kernel_id, core, std::array{0}); + handle_args(program, reader_kernel_id, core, std::array{0}); + handle_args(program, writer_kernel_id, core, std::array{0}); + handle_args(program, compute_kernel_id, core, std::array{0}); continue; } @@ -74,7 +77,7 @@ void set_or_update_runtime_arguments( uint32_t packed_scalar_eps = pack_two_bfloat16_into_uint32({bfloat_scalar_eps, bfloat_scalar_eps}); std::array reader_runtime_args = { packed_scalar_eps, - a.buffer()->address(), + input_tensor.buffer()->address(), start_tile_id, num_tiles_per_core, cHtWt, @@ -86,14 +89,14 @@ void set_or_update_runtime_arguments( cWt}; handle_args(program, reader_kernel_id, core, reader_runtime_args); - const auto weight_addr = weight_has_value ? e->buffer()->address() : 0; - const auto bias_addr = bias_has_value ? f->buffer()->address() : 0; + const auto weight_addr = weight_has_value ? weight_tensor->buffer()->address() : 0; + const auto bias_addr = bias_has_value ? bias_tensor->buffer()->address() : 0; std::array writer_runtime_args = { - b.buffer()->address(), // batch mean - d.buffer()->address(), // batch var - weight_addr, // weight - bias_addr, // bias - c.buffer()->address(), // output + batch_mean_tensor.buffer()->address(), // batch mean + batch_var_tensor.buffer()->address(), // batch var + weight_addr, // weight + bias_addr, // bias + c.buffer()->address(), // output start_tile_id, num_tiles_per_core, cHtWt, @@ -126,21 +129,23 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch using namespace tt; using namespace tt::tt_metal; - const auto& [a, b, d, e, f, _] = tensor_args; + const auto& [input_tensor, batch_mean_tensor, batch_var_tensor, weight_tensor, bias_tensor, _] = tensor_args; auto program = CreateProgram(); - auto* device = a.device(); + auto* device = input_tensor.device(); - const bool weight_has_value = e.has_value(); - const bool bias_has_value = f.has_value(); + const bool weight_has_value = weight_tensor.has_value(); + const bool bias_has_value = bias_tensor.has_value(); - auto a_data_format = datatype_to_dataformat_converter(a.get_dtype()); - auto b_data_format = datatype_to_dataformat_converter(b.get_dtype()); + auto a_data_format = datatype_to_dataformat_converter(input_tensor.get_dtype()); + auto b_data_format = datatype_to_dataformat_converter(batch_mean_tensor.get_dtype()); auto c_data_format = datatype_to_dataformat_converter(output.get_dtype()); - auto d_data_format = datatype_to_dataformat_converter(d.get_dtype()); - auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b; - auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b; + auto d_data_format = datatype_to_dataformat_converter(batch_var_tensor.get_dtype()); + auto e_data_format = + weight_has_value ? datatype_to_dataformat_converter(weight_tensor->get_dtype()) : DataFormat::Float16_b; + auto f_data_format = + bias_has_value ? datatype_to_dataformat_converter(bias_tensor->get_dtype()) : DataFormat::Float16_b; uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format); uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format); @@ -206,12 +211,12 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch auto [temp_1_cb, temp_1_cb_handle] = create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); - auto a_is_dram = static_cast(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM); - auto b_is_dram = static_cast(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto a_is_dram = static_cast(input_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto b_is_dram = static_cast(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); auto c_is_dram = static_cast(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM); - auto d_is_dram = static_cast(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM); - const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM; - const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM; + auto d_is_dram = static_cast(batch_var_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + const auto e_is_dram = weight_has_value and weight_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM; + const auto f_is_dram = bias_has_value and bias_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM; // READER KERNEL auto reader_kernel_id = tt_metal::CreateKernel( diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp index 085c775caa25..7f476e8f2ea8 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/running_statistics_program_factory.cpp @@ -30,18 +30,18 @@ void set_or_update_runtime_arguments( const RunningStatistics::tensor_args_t& tensor_args, RunningStatistics::tensor_return_value_t& c, F handle_args) { - const auto& [a, b, d, e] = tensor_args; + const auto& [batch_mean_tensor, batch_var_tensor, running_mean_tensor, running_var_tensor] = tensor_args; const auto momentum = operation_attributes.momentum; - const bool running_mean_has_value = d.has_value(); - const bool running_var_has_value = e.has_value(); + const bool running_mean_has_value = running_mean_tensor.has_value(); + const bool running_var_has_value = running_var_tensor.has_value(); - const auto ashape = a.padded_shape(); - const auto bshape = b.padded_shape(); + const auto ashape = batch_mean_tensor.padded_shape(); + const auto bshape = batch_var_tensor.padded_shape(); const auto cshape = c.padded_shape(); - const auto [aN, aC, aHt, aWt] = extract_shape_dims(a); - const auto [bN, bC, bHt, bWt] = extract_shape_dims(b); + const auto [aN, aC, aHt, aWt] = extract_shape_dims(batch_mean_tensor); + const auto [bN, bC, bHt, bWt] = extract_shape_dims(batch_var_tensor); const auto [cN, cC, cHt, cWt] = extract_shape_dims(c); uint32_t num_output_tiles = c.volume() / c.tensor_spec().tile().get_tile_hw(); @@ -55,6 +55,9 @@ void set_or_update_runtime_arguments( tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_tiles, row_major); auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); + constexpr size_t num_reader_args = 11; + constexpr size_t num_writer_args = 13; + constexpr size_t num_kernel_args = 3; for (uint32_t i = 0, start_tile_id = 0; i < num_cores_total; i++) { const auto& core = cores[i]; @@ -64,9 +67,9 @@ void set_or_update_runtime_arguments( } else if (core_group_2.contains(core)) { num_tiles_per_core = num_tiles_per_core_group_2; } else { - handle_args(program, reader_kernel_id, core, std::array{0}); - handle_args(program, writer_kernel_id, core, std::array{0}); - handle_args(program, compute_kernel_id, core, std::array{0}); + handle_args(program, reader_kernel_id, core, std::array{0}); + handle_args(program, writer_kernel_id, core, std::array{0}); + handle_args(program, compute_kernel_id, core, std::array{0}); continue; } @@ -76,7 +79,7 @@ void set_or_update_runtime_arguments( pack_two_bfloat16_into_uint32({bfloat_scalar_momentum, bfloat_scalar_momentum}); std::array reader_runtime_args = { packed_scalar_momentum, - a.buffer()->address(), + batch_mean_tensor.buffer()->address(), start_tile_id, num_tiles_per_core, cHtWt, @@ -88,13 +91,13 @@ void set_or_update_runtime_arguments( cWt}; handle_args(program, reader_kernel_id, core, reader_runtime_args); - const auto running_mean_addr = running_mean_has_value ? d->buffer()->address() : 0; - const auto running_var_addr = running_var_has_value ? e->buffer()->address() : 0; + const auto running_mean_addr = running_mean_has_value ? running_mean_tensor->buffer()->address() : 0; + const auto running_var_addr = running_var_has_value ? running_var_tensor->buffer()->address() : 0; std::array writer_runtime_args = { - b.buffer()->address(), // batch var - running_mean_addr, // old running mean - running_var_addr, // old running var - c.buffer()->address(), // output + batch_var_tensor.buffer()->address(), // batch var + running_mean_addr, // old running mean + running_var_addr, // old running var + c.buffer()->address(), // output start_tile_id, num_tiles_per_core, cHtWt, @@ -128,22 +131,22 @@ RunningStatistics::RunningStatisticsProgramFactory::create( using namespace tt; using namespace tt::tt_metal; - const auto& [a, b, d, e] = tensor_args; + const auto& [batch_mean_tensor, batch_var_tensor, running_mean_tensor, running_var_tensor] = tensor_args; auto program = CreateProgram(); - auto* device = a.device(); + auto* device = batch_mean_tensor.device(); - const bool running_mean_has_value = d.has_value(); - const bool running_var_has_value = e.has_value(); + const bool running_mean_has_value = running_mean_tensor.has_value(); + const bool running_var_has_value = running_var_tensor.has_value(); - auto a_data_format = datatype_to_dataformat_converter(a.get_dtype()); - auto b_data_format = datatype_to_dataformat_converter(b.get_dtype()); + auto a_data_format = datatype_to_dataformat_converter(batch_mean_tensor.get_dtype()); + auto b_data_format = datatype_to_dataformat_converter(batch_var_tensor.get_dtype()); auto c_data_format = datatype_to_dataformat_converter(output.get_dtype()); - auto d_data_format = - running_mean_has_value ? datatype_to_dataformat_converter(d->get_dtype()) : DataFormat::Float16_b; - auto e_data_format = - running_var_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b; + auto d_data_format = running_mean_has_value ? datatype_to_dataformat_converter(running_mean_tensor->get_dtype()) + : DataFormat::Float16_b; + auto e_data_format = running_var_has_value ? datatype_to_dataformat_converter(running_var_tensor->get_dtype()) + : DataFormat::Float16_b; uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format); uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format); @@ -235,11 +238,13 @@ RunningStatistics::RunningStatisticsProgramFactory::create( auto [tmp3_cb, tmp3_cb_handle] = create_cb(tt::CBIndex::c_23, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format); - auto a_is_dram = static_cast(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM); - auto b_is_dram = static_cast(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto a_is_dram = static_cast(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto b_is_dram = static_cast(batch_var_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM); auto c_is_dram = static_cast(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM); - const auto d_is_dram = running_mean_has_value and d->buffer()->buffer_type() == tt_metal::BufferType::DRAM; - const auto e_is_dram = running_var_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM; + const auto d_is_dram = + running_mean_has_value and running_mean_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM; + const auto e_is_dram = + running_var_has_value and running_var_tensor->buffer()->buffer_type() == tt_metal::BufferType::DRAM; // READER KERNEL auto reader_kernel_id = tt_metal::CreateKernel(