Skip to content

Commit

Permalink
#12253: Update with momentum, training, remove inference constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jan 13, 2025
1 parent e258476 commit a505d15
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@ Tensor BatchNorm::invoke(
std::optional<Tensor> running_var,
const bool training,
const float eps,
const float momentum,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
// TODO: Implementation for training mode is in progress
TT_FATAL((!training), "Support currently provided for inference mode only");
TT_FATAL(
(running_mean.has_value() && running_var.has_value()),
(running_mean.has_value() && running_var.has_value() && (!training)),
"running_mean and running_var must be defined in evaluation mode");
return ttnn::prim::batch_norm(
input, running_mean.value(), running_var.value(), eps, weight, bias, output, memory_config);
input, running_mean.value(), running_var.value(), eps, momentum, training, weight, bias, output, memory_config);
}
} // namespace ttnn::operations::normalization
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct BatchNorm {
std::optional<Tensor> running_var = std::nullopt,
const bool training = false,
const float eps = 1e-05,
const float momentum = 0.1,
std::optional<Tensor> weight = std::nullopt,
std::optional<Tensor> bias = std::nullopt,
std::optional<Tensor> output = std::nullopt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
Keyword args:
eps (float, optional): Epsilon value. Defaults to `1e-05`.
momentum (float, optional): Momentum value. Defaults to `0.1`.
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
Expand All @@ -44,6 +45,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
py::arg("running_var") = std::nullopt,
py::arg("training") = false,
py::arg("eps") = 1e-05,
py::arg("momentum") = 0.1,
py::arg("weight") = std::nullopt,
py::arg("bias") = std::nullopt,
py::arg("output") = std::nullopt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,13 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
const Tensor& batch_mean,
const Tensor& batch_var,
const float eps,
const float momentum,
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config())};
operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
return {operation_attributes, tensor_args};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace ttnn::operations::normalization {
struct BatchNormOperation {
struct operation_attributes_t {
const float eps;
const float momentum;
const bool training;
const MemoryConfig memory_config;

DataType input_dtype;
Expand Down Expand Up @@ -66,6 +68,8 @@ struct BatchNormOperation {
const Tensor& batch_mean,
const Tensor& batch_var,
const float eps,
const float momentum,
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> output,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void set_or_update_runtime_arguments(
F handle_args) {
const auto& [a, b, d, e, f, _] = tensor_args;
const auto eps = operation_attributes.eps;
const auto momentum = operation_attributes.momentum;

const bool weight_has_value = e.has_value();
const bool bias_has_value = f.has_value();
Expand Down Expand Up @@ -63,17 +64,21 @@ void set_or_update_runtime_arguments(
} else if (core_group_2.contains(core)) {
num_tiles_per_core = num_tiles_per_core_group_2;
} else {
handle_args(program, reader_kernel_id, core, std::array<uint32_t, 11>{0});
handle_args(program, reader_kernel_id, core, std::array<uint32_t, 12>{0});
handle_args(program, writer_kernel_id, core, std::array<uint32_t, 14>{0});
handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
continue;
}

uint32_t cHtWt = cHt * cWt;
class bfloat16 bfloat_scalar(eps);
uint32_t packed_scalar = pack_two_bfloat16_into_uint32({bfloat_scalar, bfloat_scalar});
class bfloat16 bfloat_scalar_eps(eps);
uint32_t packed_scalar_eps = pack_two_bfloat16_into_uint32({bfloat_scalar_eps, bfloat_scalar_eps});
class bfloat16 bfloat_scalar_momentum(momentum);
uint32_t packed_scalar_momentum =
pack_two_bfloat16_into_uint32({bfloat_scalar_momentum, bfloat_scalar_momentum});
std::array reader_runtime_args = {
packed_scalar,
packed_scalar_eps,
packed_scalar_momentum,
a.buffer()->address(),
start_tile_id,
num_tiles_per_core,
Expand Down Expand Up @@ -187,6 +192,13 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
tt::CBIndex::c_16, program, all_device_cores, e_single_tile_size, b_num_tiles_per_cb, e_data_format); // weight
auto [f_cb, f_cb_handle] = create_cb(
tt::CBIndex::c_18, program, all_device_cores, f_single_tile_size, b_num_tiles_per_cb, f_data_format); // bias
auto [momentum_cb, momentum_cb_handle] = create_cb(
tt::CBIndex::c_24,
program,
all_device_cores,
d_single_tile_size,
b_num_tiles_per_cb,
d_data_format); // momentum

// Temporary buffers to store intermediate results
auto [den_cb, den_cb_handle] = create_cb(
Expand Down Expand Up @@ -232,13 +244,16 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
e_is_dram,
f_is_dram,
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value)}));
static_cast<uint32_t>(bias_has_value),
static_cast<uint32_t>(operation_attributes.training)}));

// COMPUTE KERNEL
bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
c_data_format == tt::DataFormat::Float32;
std::vector<uint32_t> compute_kernel_args = {
static_cast<uint32_t>(weight_has_value), static_cast<uint32_t>(bias_has_value)};
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
static_cast<uint32_t>(operation_attributes.training)};
auto compute_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void MAIN {
uint32_t tile_start = get_arg_val<uint32_t>(2);
constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
constexpr uint32_t is_training_mode = get_compile_time_arg_val(2) == 1;

if (num_tiles == 0) {
return;
Expand Down Expand Up @@ -116,6 +117,10 @@ void MAIN {
cb_pop_front(cb_den, 1);
cb_push_back(cb_affine_or_out, onetile);

if constexpr (is_training_mode) {
// update running stats here
}

if constexpr (weight_has_value) { // result = result * weight
cb_reserve_back(cb_scaled_output, onetile);
cb_wait_front(cb_affine_or_out, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@

void kernel_main() {
const auto eps = get_arg_val<uint32_t>(0);
uint32_t src_addr = get_arg_val<uint32_t>(1); // input tensor
uint32_t start_tile_id = get_arg_val<uint32_t>(2);
uint32_t num_tiles = get_arg_val<uint32_t>(3);
uint32_t HtWt = get_arg_val<uint32_t>(4);
uint32_t n_stride = get_arg_val<uint32_t>(5);
uint32_t c_stride = get_arg_val<uint32_t>(6);
uint32_t N = get_arg_val<uint32_t>(7);
uint32_t C = get_arg_val<uint32_t>(8);
const auto momentum = get_arg_val<uint32_t>(1);
uint32_t src_addr = get_arg_val<uint32_t>(2); // input tensor
uint32_t start_tile_id = get_arg_val<uint32_t>(3);
uint32_t num_tiles = get_arg_val<uint32_t>(4);
uint32_t HtWt = get_arg_val<uint32_t>(5);
uint32_t n_stride = get_arg_val<uint32_t>(6);
uint32_t c_stride = get_arg_val<uint32_t>(7);
uint32_t N = get_arg_val<uint32_t>(8);
uint32_t C = get_arg_val<uint32_t>(9);

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

Expand All @@ -41,6 +42,12 @@ void kernel_main() {
fill_with_val_bfloat16(cb_id_eps, eps);
cb_push_back(cb_id_eps, onetile);

constexpr auto cb_id_momentum = tt::CBIndex::c_24;

cb_reserve_back(cb_id_momentum, onetile);
fill_with_val_bfloat16(cb_id_momentum, momentum);
cb_push_back(cb_id_momentum, onetile);

// Input tile offset
uint32_t tile_offset = start_n * n_stride + start_c * c_stride + start_t;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ void kernel_main() {

constexpr bool weight_has_value = get_compile_time_arg_val(5) == 1;
constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1;
constexpr bool is_training_mode = get_compile_time_arg_val(7) == 1;

uint32_t tiles_per_batch = HtWt * C;
uint32_t start_n = start_tile_id / tiles_per_batch;
Expand Down Expand Up @@ -118,6 +119,10 @@ void kernel_main() {
cb_push_back(cb_id_bias, onetile);
}

// to read running stats value for updation
if constexpr (is_training_mode) {
}

for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
// write a tile to dst
cb_wait_front(cb_id_dst, onetile);
Expand Down

0 comments on commit a505d15

Please sign in to comment.