Skip to content

Commit

Permalink
#17758: Update Batch Norm Training mode kernels (#17733)
Browse files Browse the repository at this point in the history
### Ticket
#17758

### Problem description
[Comment
Link](#17587 (comment))

### What's changed
Updated BN to use compile-time arguments for buffer indexing, replacing
hardcoded values for better flexibility.

### Checklist
- [x] [All post-commit
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13227397570)
- [x] [Blackhole post-commit
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13227398013)
- [ ] [(Single-card) Tests for new models]()
- [x] [(Single-card) Demo
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13227399196)
- [x] [(Single-card) Device perf
regressions](https://github.com/tenstorrent/tt-metal/actions/runs/13227399904)
- [x] [(Single-card) Model perf
tests](https://github.com/tenstorrent/tt-metal/actions/runs/13227400809)
  • Loading branch information
VirdhatchaniKN authored Feb 10, 2025
1 parent 65b32c9 commit 359ff79
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,18 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
uint32_t b_num_tiles_per_cb = num_tiles_per_cb;

// Input buffers
auto [a_cb, a_cb_handle] = create_cb(
auto [input_tensor_cb, input_tensor_cb_handle] = create_cb(
tt::CBIndex::c_0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); // input
auto [b_cb, b_cb_handle] = create_cb(
auto [batch_mean_tensor_cb, batch_mean_tensor_cb_handle] = create_cb(
tt::CBIndex::c_1,
program,
all_device_cores,
b_single_tile_size,
b_num_tiles_per_cb,
b_data_format); // batch_mean
auto [c_cb, c_cb_handle] = create_cb(
auto [output_tensor_cb, output_tensor_cb_handle] = create_cb(
tt::CBIndex::c_2, program, all_device_cores, c_single_tile_size, num_tiles_per_cb, c_data_format); // output
auto [d_cb, d_cb_handle] = create_cb(
auto [batch_var_tensor_cb, batch_var_tensor_cb_handle] = create_cb(
tt::CBIndex::c_3,
program,
all_device_cores,
Expand All @@ -191,28 +191,28 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
d_data_format); // batch_var
auto [eps_cb, eps_cb_handle] = create_cb(
tt::CBIndex::c_4, program, all_device_cores, d_single_tile_size, b_num_tiles_per_cb, d_data_format); // eps
auto [e_cb, e_cb_handle] = create_cb(
tt::CBIndex::c_16, program, all_device_cores, e_single_tile_size, b_num_tiles_per_cb, e_data_format); // weight
auto [f_cb, f_cb_handle] = create_cb(
tt::CBIndex::c_18, program, all_device_cores, f_single_tile_size, b_num_tiles_per_cb, f_data_format); // bias
auto [weight_tensor_cb, weight_tensor_cb_handle] = create_cb(
tt::CBIndex::c_5, program, all_device_cores, e_single_tile_size, b_num_tiles_per_cb, e_data_format); // weight
auto [bias_tensor_cb, bias_tensor_cb_handle] = create_cb(
tt::CBIndex::c_6, program, all_device_cores, f_single_tile_size, b_num_tiles_per_cb, f_data_format); // bias

// Temporary buffers to store intermediate results
auto [den_cb, den_cb_handle] = create_cb(
tt::CBIndex::c_5,
tt::CBIndex::c_7,
program,
all_device_cores,
a_single_tile_size,
num_tiles_per_cb,
a_data_format); // to store 1/(sqrt(batch_var + eps))
auto [num_cb, num_cb_handle] = create_cb(
tt::CBIndex::c_6,
tt::CBIndex::c_8,
program,
all_device_cores,
a_single_tile_size,
num_tiles_per_cb,
a_data_format); // to store input - batch_mean
auto [temp_1_cb, temp_1_cb_handle] =
create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
create_cb(tt::CBIndex::c_9, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);

auto a_is_dram = static_cast<uint32_t>(input_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
auto b_is_dram = static_cast<uint32_t>(batch_mean_tensor.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
Expand All @@ -236,7 +236,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/reader_batch_norm.cpp",
all_device_cores,
tt_metal::ReaderDataMovementConfig({a_is_dram}, std::move(reader_defines)));
tt_metal::ReaderDataMovementConfig({a_is_dram, input_tensor_cb, eps_cb}, std::move(reader_defines)));

// WRITER KERNEL
auto writer_defines = dataflow_defines;
Expand All @@ -253,41 +253,47 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
f_is_dram,
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
batch_mean_tensor_cb,
output_tensor_cb,
batch_var_tensor_cb,
weight_tensor_cb,
bias_tensor_cb,
},
std::move(writer_defines)));

// COMPUTE KERNEL
bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
c_data_format == tt::DataFormat::Float32;

uint32_t src_input_cb_index = tt::CBIndex::c_0;
uint32_t src_batch_mean_cb_index = tt::CBIndex::c_1;
uint32_t src_batch_var_cb_index = tt::CBIndex::c_3;
uint32_t src_eps_cb_index = tt::CBIndex::c_4;
uint32_t src_temp_den_cb_index = tt::CBIndex::c_5;
uint32_t src_temp_num_cb_index = tt::CBIndex::c_6;
uint32_t src_weight_cb_index = tt::CBIndex::c_16;
uint32_t src_temp_1_cb_index = tt::CBIndex::c_17;
uint32_t src_bias_cb_index = tt::CBIndex::c_18;

std::vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
if (fp32_dest_acc_en) {
for (const auto cb_index :
{src_input_cb_index,
src_batch_mean_cb_index,
src_batch_var_cb_index,
src_temp_num_cb_index,
src_temp_den_cb_index,
src_eps_cb_index,
src_weight_cb_index,
src_temp_1_cb_index,
src_bias_cb_index}) {
{input_tensor_cb,
batch_mean_tensor_cb,
batch_var_tensor_cb,
eps_cb,
den_cb,
num_cb,
weight_tensor_cb,
temp_1_cb,
bias_tensor_cb}) {
unpack_to_dest_mode[cb_index] = UnpackToDestMode::UnpackToDestFp32;
}
}

std::vector<uint32_t> compute_kernel_args = {
static_cast<uint32_t>(weight_has_value), static_cast<uint32_t>(bias_has_value)};
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
input_tensor_cb,
batch_mean_tensor_cb,
output_tensor_cb,
batch_var_tensor_cb,
eps_cb,
den_cb,
num_cb,
weight_tensor_cb,
temp_1_cb,
bias_tensor_cb};
auto compute_kernel_id = tt_metal::CreateKernel(
program,
fmt::format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,17 @@ void MAIN {
return;
}

constexpr auto cb_input = tt::CBIndex::c_0; // input
constexpr auto cb_batch_mean = tt::CBIndex::c_1; // batch_mean
constexpr auto cb_input = get_compile_time_arg_val(2); // input
constexpr auto cb_batch_mean = get_compile_time_arg_val(3); // batch_mean
constexpr auto cb_output_0 =
tt::CBIndex::c_2; // output -- > [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
constexpr auto cb_batch_var = tt::CBIndex::c_3; // batch_var
constexpr auto cb_eps = tt::CBIndex::c_4; // eps
constexpr auto cb_den = tt::CBIndex::c_5; // 1/(sqrt(batch_var + eps))
constexpr auto cb_num = tt::CBIndex::c_6; // input - batch_mean
constexpr auto cb_weight = tt::CBIndex::c_16; // weight tensor
constexpr auto cb_tmp_1 = tt::CBIndex::c_17; // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = tt::CBIndex::c_18; // bias tensor
get_compile_time_arg_val(4); // output -- > [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
constexpr auto cb_batch_var = get_compile_time_arg_val(5); // batch_var
constexpr auto cb_eps = get_compile_time_arg_val(6); // eps
constexpr auto cb_den = get_compile_time_arg_val(7); // 1/(sqrt(batch_var + eps))
constexpr auto cb_num = get_compile_time_arg_val(8); // input - batch_mean
constexpr auto cb_weight = get_compile_time_arg_val(9); // weight tensor
constexpr auto cb_tmp_1 = get_compile_time_arg_val(10); // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = get_compile_time_arg_val(11); // bias tensor

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,17 @@ void MAIN {
return;
}

constexpr auto cb_input = tt::CBIndex::c_0; // input
constexpr auto cb_batch_mean = tt::CBIndex::c_1; // batch_mean
constexpr auto cb_input = get_compile_time_arg_val(2); // input
constexpr auto cb_batch_mean = get_compile_time_arg_val(3); // batch_mean
constexpr auto cb_output_0 =
tt::CBIndex::c_2; // output -- > [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
constexpr auto cb_batch_var = tt::CBIndex::c_3; // batch_var
constexpr auto cb_eps = tt::CBIndex::c_4; // eps
constexpr auto cb_den = tt::CBIndex::c_5; // 1/(sqrt(batch_var + eps))
constexpr auto cb_num = tt::CBIndex::c_6; // input - batch_mean
constexpr auto cb_weight = tt::CBIndex::c_16; // weight tensor
constexpr auto cb_tmp_1 = tt::CBIndex::c_17; // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = tt::CBIndex::c_18; // bias tensor
get_compile_time_arg_val(4); // output -- > [(input - batch_mean)/(sqrt(batch_var + eps))] * weight
constexpr auto cb_batch_var = get_compile_time_arg_val(5); // batch_var
constexpr auto cb_eps = get_compile_time_arg_val(6); // eps
constexpr auto cb_den = get_compile_time_arg_val(7); // 1/(sqrt(batch_var + eps))
constexpr auto cb_num = get_compile_time_arg_val(8); // input - batch_mean
constexpr auto cb_weight = get_compile_time_arg_val(9); // weight tensor
constexpr auto cb_tmp_1 = get_compile_time_arg_val(10); // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = get_compile_time_arg_val(11); // bias tensor

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ void kernel_main() {

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr auto cb_id_src = tt::CBIndex::c_0;
constexpr auto cb_id_src = get_compile_time_arg_val(1);
constexpr uint32_t onetile = 1;

const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
Expand All @@ -35,7 +35,7 @@ void kernel_main() {
uint32_t start_c = start_remaining / HtWt;
uint32_t start_t = start_remaining % HtWt;

constexpr auto cb_id_eps = tt::CBIndex::c_4;
constexpr auto cb_id_eps = get_compile_time_arg_val(2);

union {
float f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void kernel_main() {
constexpr uint32_t onetile = 1;

// batch_mean
constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr auto cb_id_src = get_compile_time_arg_val(7);
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);
Expand All @@ -33,7 +33,7 @@ void kernel_main() {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

// output
constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr auto cb_id_dst = get_compile_time_arg_val(8);
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);
Expand All @@ -42,7 +42,7 @@ void kernel_main() {
.bank_base_address = dst_addr, .page_size = dst_tile_bytes, .data_format = dst_data_format};

// batch_var
constexpr auto cb_id_batch_var = tt::CBIndex::c_3;
constexpr auto cb_id_batch_var = get_compile_time_arg_val(9);
constexpr bool batch_var_is_dram = get_compile_time_arg_val(2) == 1;
const uint32_t batch_var_tile_bytes = get_tile_size(cb_id_batch_var);
const DataFormat batch_var_data_format = get_dataformat(cb_id_batch_var);
Expand All @@ -51,7 +51,7 @@ void kernel_main() {
.bank_base_address = batch_var_addr, .page_size = batch_var_tile_bytes, .data_format = batch_var_data_format};

// weight
constexpr auto cb_id_weight = tt::CBIndex::c_16;
constexpr auto cb_id_weight = get_compile_time_arg_val(10);
constexpr bool weight_is_dram = get_compile_time_arg_val(3) == 1;
const uint32_t weight_tile_bytes = get_tile_size(cb_id_weight);
const DataFormat weight_data_format = get_dataformat(cb_id_weight);
Expand All @@ -60,7 +60,7 @@ void kernel_main() {
.bank_base_address = weight_addr, .page_size = weight_tile_bytes, .data_format = weight_data_format};

// bias
constexpr auto cb_id_bias = tt::CBIndex::c_18;
constexpr auto cb_id_bias = get_compile_time_arg_val(11);
constexpr bool bias_is_dram = get_compile_time_arg_val(4) == 1;
const uint32_t bias_tile_bytes = get_tile_size(cb_id_bias);
const DataFormat bias_data_format = get_dataformat(cb_id_bias);
Expand Down

0 comments on commit 359ff79

Please sign in to comment.