#0: Write updated running stats
VirdhatchaniKN committed Jan 10, 2025

1 parent e3e7457 commit 9544031
Showing 5 changed files with 59 additions and 11 deletions.
27 changes: 20 additions & 7 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -24,21 +24,34 @@
         torch.Size([3, 2, 64, 120]),
     ],
 )
-@pytest.mark.parametrize("training", [True, False])
+@pytest.mark.parametrize(
+    "training, check_mean, check_var",
+    [
+        # (True, True, True),
+        # (True, True, False),
+        # (True, False, True),
+        (True, False, False),
+        (False, False, False),  # xfail case
+        (False, True, False),  # xfail case
+        (False, False, True),  # xfail case
+        (False, True, True),
+    ],
+)
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-def test_batch_norm(input_shapes, training, weight, bias, eps, device):
+@pytest.mark.parametrize("eps", [1.05])
+def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, device):
     in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
     mean_data, mean_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
     )
-    var_data, var_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
-    )
+    var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
     weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
     bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)
 
+    if (not check_mean) or (not check_var):
+        pytest.xfail("running_mean and running_var must be defined in evaluation mode")
+
     tt_output_tensor_on_device = ttnn.batch_norm(
         input_tensor,
         running_mean=mean_tensor,
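The new xfail combinations encode the same constraint PyTorch enforces: evaluation-mode batch norm needs both running statistics. As a point of reference, a sketch of the equivalent check in stock PyTorch (the exact error text can vary between torch versions):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 32, 32)
try:
    # Evaluation mode without running stats is rejected, which is why the
    # (training=False, check_mean=False/check_var=False) cases above are xfail.
    F.batch_norm(x, running_mean=None, running_var=None, training=False)
except RuntimeError as err:
    print(err)  # e.g. "running_mean must be defined in evaluation mode"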
Changes to the batch_norm pybind docstring (bind_batch_norm_operation):
@@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
         module,
         ttnn::batch_norm,
         R"doc(
-    Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved. Currently support is provided for inference mode only.
+    Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved.
 
 
     Args:
@@ -24,8 +24,8 @@ void bind_batch_norm_operation(pybind11::module& module) {
     Keyword args:
         eps (float, optional): Epsilon value. Defaults to `1e-05`.
         momentum (float, optional): Momentum value. Defaults to `0.1`.
-        running_mean (ttnn.Tensor, optional): the running mean of shape `[1, C, 1, 1]`, required in inference mode. Defaults to `None`.
-        running_var (ttnn.Tensor, optional): the running variance of shape `[1, C, 1, 1]`, required in inference mode. Defaults to `None`.
+        running_mean (ttnn.Tensor, optional): the running mean of shape `[1, C, 1, 1]`, required in inference mode. In training mode this tensor is optional; when provided, the updated running mean is stored into it in place. Defaults to `None`.
+        running_var (ttnn.Tensor, optional): the running variance of shape `[1, C, 1, 1]`, required in inference mode. In training mode this tensor is optional; when provided, the updated running variance is stored into it in place. Defaults to `None`.
         weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
         bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
         training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (inference mode).
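Putting the documented keyword arguments together, a minimal host-side sketch of a training-mode call (the to_tt helper below is illustrative, not part of the API; standard ttnn.from_torch usage is assumed):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

def to_tt(t):
    # batch_norm expects tilized, interleaved device tensors
    return ttnn.from_torch(t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

N, C, H, W = 2, 3, 64, 64
input_tensor = to_tt(torch.randn(N, C, H, W))
running_mean = to_tt(torch.zeros(1, C, 1, 1))
running_var = to_tt(torch.ones(1, C, 1, 1))

# With training=True, running_mean and running_var are optional, and when
# provided they are updated in place from the batch statistics and momentum.
output = ttnn.batch_norm(
    input_tensor,
    running_mean=running_mean,
    running_var=running_var,
    training=True,
    eps=1e-05,
    momentum=0.1,
)

ttnn.close_device(device)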
Changes to the batch norm program factory (BatchNormOperation::BatchNormFactory):
@@ -247,6 +247,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         a_data_format);  // to store input - batch_mean
     auto [temp_1_cb, temp_1_cb_handle] =
         create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+    auto [updated_m_cb, updated_m_cb_handle] = create_cb(
+        tt::CBIndex::c_27,
+        program,
+        all_device_cores,
+        g_single_tile_size,
+        b_num_tiles_per_cb,
+        g_data_format);  // updated running mean
+    auto [updated_v_cb, updated_v_cb_handle] = create_cb(
+        tt::CBIndex::c_28,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // updated running var
 
     auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
Changes to the batch norm compute kernel:
@@ -60,6 +60,9 @@ void MAIN {
     constexpr auto cb_bias = tt::CBIndex::c_18;              // bias tensor
     constexpr auto cb_old_running_mean = tt::CBIndex::c_25;  // old running mean tensor
     constexpr auto cb_old_running_var = tt::CBIndex::c_26;   // old running var tensor
+    constexpr auto cb_updated_running_mean = tt::CBIndex::c_27;  // updated running mean tensor
+    constexpr auto cb_updated_running_var = tt::CBIndex::c_28;   // updated running var tensor
+    constexpr auto cb_momentum = tt::CBIndex::c_24;              // momentum
 
     auto cb_bcast = cb_batch_mean;
     auto cb_other = cb_input;
@@ -122,7 +125,7 @@ void MAIN {
             cb_push_back(cb_affine_or_out, onetile);
 
             if constexpr (is_training_mode) {
-                // update running stats here
+                // update running stats
                 if constexpr (old_running_mean_has_value) {
                 }
 
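The new CBs c_27 and c_28 carry the updated statistics out of the compute kernel, with cb_momentum supplying the blend factor. Assuming the usual PyTorch momentum convention (consistent with the documented default of 0.1), the arithmetic behind the update amounts to the following host-side sketch, not the kernel code itself:

def update_running_stats(running_mean, running_var, batch_mean, batch_var, momentum=0.1):
    # Exponential moving average, PyTorch convention: new = (1 - m) * old + m * batch
    new_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
    new_var = (1.0 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var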
Changes to the batch norm dataflow (reader) kernel:
@@ -98,6 +98,8 @@ void kernel_main() {
 
     constexpr bool old_running_mean_has_value = get_compile_time_arg_val(10) == 1;
     constexpr bool old_running_var_has_value = get_compile_time_arg_val(11) == 1;
+    constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
+    constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;
 
     uint32_t tiles_per_batch = HtWt * C;
     uint32_t start_n = start_tile_id / tiles_per_batch;
@@ -149,21 +151,37 @@ void kernel_main() {
         // read the running stats values for the update
         if constexpr (is_training_mode) {
             if constexpr (old_running_mean_has_value) {
+                // read the old running mean tile
                 cb_reserve_back(cb_id_old_running_mean, onetile);
                 uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean);
                 noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr);
                 noc_async_read_barrier();
                 fill_tile_with_first_element_bfloat16(cb_id_old_running_mean);
                 cb_push_back(cb_id_old_running_mean, onetile);
+
+                // write the updated running mean back to the same tile (in-place update)
+                cb_wait_front(cb_id_updated_running_mean, onetile);
+                uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean);
+                noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr);
+                noc_async_write_barrier();
+                cb_pop_front(cb_id_updated_running_mean, onetile);
             }
 
             if constexpr (old_running_var_has_value) {
+                // read the old running var tile
                 cb_reserve_back(cb_id_old_running_var, onetile);
                 uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var);
                 noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr);
                 noc_async_read_barrier();
                 fill_tile_with_first_element_bfloat16(cb_id_old_running_var);
                 cb_push_back(cb_id_old_running_var, onetile);
+
+                // write the updated running var back to the same tile (in-place update)
+                cb_wait_front(cb_id_updated_running_var, onetile);
+                uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var);
+                noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr);
+                noc_async_write_barrier();
+                cb_pop_front(cb_id_updated_running_var, onetile);
             }
         }
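The in-place semantics come from this read-modify-write sequence: the kernel stages the old statistic into one CB for the compute kernel, waits for the updated tile on another CB, and writes it back to the same tile_offset in the caller's buffer. A toy Python model of that handshake (the queues and values are illustrative, not ttnn APIs):

import threading
from queue import Queue

cb_old = Queue(maxsize=1)      # stands in for c_25/c_26 (old running stats)
cb_updated = Queue(maxsize=1)  # stands in for c_27/c_28 (updated running stats)
dram = {"running_mean": 0.5}   # stands in for the caller's running_mean buffer

def reader_writer():
    cb_old.put(dram["running_mean"])         # noc_async_read_tile + cb_push_back
    dram["running_mean"] = cb_updated.get()  # cb_wait_front + noc_async_write_tile

def compute(batch_mean=1.0, momentum=0.1):
    old = cb_old.get()  # consume the old stat pushed by the reader
    cb_updated.put((1.0 - momentum) * old + momentum * batch_mean)  # push updated stat

threads = [threading.Thread(target=reader_writer), threading.Thread(target=compute)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(dram["running_mean"])  # 0.55 under the assumed EMA convention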
