#0: Update kernel with running stats
VirdhatchaniKN committed Jan 9, 2025
1 parent 340a2cf commit d9bc623
Showing 7 changed files with 217 additions and 27 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -24,7 +24,7 @@
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize("training", [False])
@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
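Enabling `training=True` doubles the parametrized matrix and exercises the new batch-stats path. For reference, the torch training-mode semantics these cases would be checked against — shapes and values here are illustrative, not taken from the test:

```python
import torch

# Training-mode batch norm: biased batch stats normalize the input, and
# torch updates the running stats in place as a side effect.
x = torch.randn(2, 3, 32, 64)
running_mean = torch.zeros(3)
running_var = torch.ones(3)
expected = torch.nn.functional.batch_norm(
    x, running_mean, running_var,
    weight=None, bias=None,
    training=True, momentum=0.1, eps=1e-5,
)
# running_mean / running_var now hold the EMA-updated statistics.
```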
50 changes: 48 additions & 2 deletions ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -5,11 +5,27 @@
#include "batch_norm.hpp"

#include "device/batch_norm_device_operation.hpp"
#include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::normalization {

inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
auto output_memory_config = memory_config.value_or(input_tensor.memory_config());
auto batch_mean = input_tensor;
ttnn::SmallVector<int64_t> dims = {0, 2, 3};
std::sort(dims.begin(), dims.end());
for (uint32_t i = dims.size() - 1; i > 0; i--) {
auto temp_output = ttnn::prim::moreh_mean(
batch_mean, dims[i], true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
batch_mean = temp_output;
}
return ttnn::prim::moreh_mean(
batch_mean, dims.front(), true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
}
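mean_NHW reduces one dimension per moreh_mean call, highest dim first, finishing with dims.front(); sequential keepdim reductions over {0, 2, 3} are equivalent to a single joint mean. A quick torch check of that equivalence (illustrative only, not part of the change):

```python
import torch

x = torch.randn(2, 3, 32, 64)  # (N, C, H, W)

# Joint mean over N, H, W, keeping a (1, C, 1, 1) shape.
joint = x.mean(dim=(0, 2, 3), keepdim=True)

# One dim at a time, highest dim first, mirroring the loop in mean_NHW.
step = x
for d in (3, 2, 0):
    step = step.mean(dim=d, keepdim=True)

assert torch.allclose(joint, step, atol=1e-6)
```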

Tensor BatchNorm::invoke(
const Tensor& input,
std::optional<Tensor> running_mean,
@@ -21,10 +37,40 @@ Tensor BatchNorm::invoke(
std::optional<Tensor> bias,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
if (training) {
Tensor batch_mean = mean_NHW(input, memory_config);
Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
Tensor batch_var =
ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
return ttnn::prim::batch_norm(
input,
batch_mean,
batch_var,
eps,
momentum,
training,
weight,
bias,
running_mean,
running_var,
output,
memory_config);
}
TT_FATAL(
(running_mean.has_value() && running_var.has_value() && (!training)),
(running_mean.has_value() && running_var.has_value()),
"running_mean and running_var must be defined in evaluation mode");
return ttnn::prim::batch_norm(
-        input, running_mean.value(), running_var.value(), eps, momentum, training, weight, bias, output, memory_config);
+        input,
+        running_mean.value(),
+        running_var.value(),
+        eps,
+        momentum,
+        training,
+        weight,
+        bias,
+        std::nullopt,
+        std::nullopt,
+        output,
+        memory_config);
}
} // namespace ttnn::operations::normalization
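In the training branch above, the batch variance comes from the identity var(x) = E[x²] − (E[x])², which needs only two mean_NHW reductions instead of a second pass over centered values. That matches the biased (population) variance batch norm uses for normalization — a small torch check (illustrative only):

```python
import torch

x = torch.randn(2, 3, 32, 64)

mean = x.mean(dim=(0, 2, 3), keepdim=True)
mean_sq = (x * x).mean(dim=(0, 2, 3), keepdim=True)
var_identity = mean_sq - mean * mean  # E[x^2] - (E[x])^2

var_biased = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
assert torch.allclose(var_identity, var_biased, atol=1e-5)
```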
ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
@@ -10,14 +10,16 @@
namespace ttnn::operations::normalization {
void BatchNormOperation::validate_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;

check_tensor(input, "batch_norm", "input");
check_tensor(batch_mean, "batch_norm", "batch_mean");
check_tensor(batch_var, "batch_norm", "batch_var");
check_tensor(weight, "batch_norm", "weight");
check_tensor(bias, "batch_norm", "bias");
check_tensor(output, "batch_norm", "output");
check_tensor(running_mean, "batch_norm", "running_mean");
check_tensor(running_var, "batch_norm", "running_var");

// input (N, C, H, W)
auto C = input.get_logical_shape()[1];
@@ -45,6 +47,26 @@ void BatchNormOperation::validate_tensors(
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
}

// running_mean (1, C, 1, 1)
if (running_mean.has_value()) {
TT_FATAL(
running_mean.value().get_logical_shape()[1] == C,
"running_mean_shape[1] must be the same as input's channel size.");
}

// running_var (1, C, 1, 1)
if (running_var.has_value()) {
TT_FATAL(
running_var.value().get_logical_shape()[1] == C,
"running_var_shape[1] must be the same as input's channel size.");
}
}
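weight, bias, running_mean, and running_var all follow the same (1, C, 1, 1) layout, and only the channel dimension is validated against the input. A compact sketch of that contract — hypothetical helper, not the actual ttnn validator:

```python
def check_channel_dim(name: str, shape: tuple, C: int) -> None:
    # Expected layout is (1, C, 1, 1); only dim 1 is checked against input C.
    if shape[1] != C:
        raise ValueError(f"{name}_shape[1] must match input channel size {C}")

C = 3
for name in ("weight", "bias", "running_mean", "running_var"):
    check_channel_dim(name, (1, 3, 1, 1), C)
```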

BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
@@ -55,7 +77,7 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory
void BatchNormOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
// We don't support sharding for now
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;

TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be tilized");
TT_FATAL(
@@ -88,6 +110,20 @@ void BatchNormOperation::validate_on_program_cache_miss(
"bias tensor must be interleaved");
}

if (running_mean.has_value()) {
TT_FATAL(running_mean.value().get_layout() == Layout::TILE, "running_mean tensor must be tilized");
TT_FATAL(
running_mean.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"running_mean tensor must be interleaved");
}

if (running_var.has_value()) {
TT_FATAL(running_var.value().get_layout() == Layout::TILE, "running_var tensor must be tilized");
TT_FATAL(
running_var.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"running_var tensor must be interleaved");
}

validate_tensors(operation_attributes, tensor_args);
};

@@ -128,10 +164,20 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
-    tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
+    tensor_args_t tensor_args{
+        input,
+        batch_mean,
+        batch_var,
+        std::move(weight),
+        std::move(bias),
+        std::move(running_mean),
+        std::move(running_var),
+        std::move(output)};
return {operation_attributes, tensor_args};
}
} // namespace ttnn::operations::normalization
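End to end, the new arguments surface at the op's public interface roughly as below — a hedged usage sketch, assuming the Python binding mirrors the C++ signature; `device`, `running_mean_tensor`, and `running_var_tensor` are assumed to exist (running stats shaped (1, C, 1, 1), tilized):

```python
import torch
import ttnn

# Illustrative only: device open/close and stats-tensor creation elided.
x = ttnn.from_torch(
    torch.randn(2, 3, 32, 64, dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# Training mode: batch stats are computed on device; running_mean /
# running_var are the tensors this commit threads through for updating.
out = ttnn.batch_norm(
    x,
    running_mean=running_mean_tensor,
    running_var=running_var_tensor,
    training=True,
    eps=1e-5,
    momentum=0.1,
)
```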
ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
@@ -28,6 +28,8 @@ struct BatchNormOperation {
std::optional<Tensor> weight;
std::optional<Tensor> bias;
std::optional<Tensor> output;
std::optional<Tensor> running_mean;
std::optional<Tensor> running_var;
};

using spec_return_value_t = TensorSpec;
@@ -72,6 +74,8 @@ struct BatchNormOperation {
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config);
};
@@ -29,12 +29,15 @@ void set_or_update_runtime_arguments(
const BatchNormOperation::tensor_args_t& tensor_args,
BatchNormOperation::tensor_return_value_t& c,
F handle_args) {
const auto& [a, b, d, e, f, _] = tensor_args;
const auto& [a, b, d, e, f, g, h, _] = tensor_args;
const auto eps = operation_attributes.eps;
const auto momentum = operation_attributes.momentum;

const bool weight_has_value = e.has_value();
const bool bias_has_value = f.has_value();
const bool running_mean_has_value = g.has_value();
const bool running_var_has_value = h.has_value();
const bool is_training_mode = operation_attributes.training;

const auto ashape = a.padded_shape();
const auto bshape = b.padded_shape();
@@ -65,7 +68,7 @@
num_tiles_per_core = num_tiles_per_core_group_2;
} else {
handle_args(program, reader_kernel_id, core, std::array<uint32_t, 12>{0});
handle_args(program, writer_kernel_id, core, std::array<uint32_t, 14>{0});
handle_args(program, writer_kernel_id, core, std::array<uint32_t, 16>{0});
handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
continue;
}
@@ -93,11 +96,15 @@

const auto weight_addr = weight_has_value ? e->buffer()->address() : 0;
const auto bias_addr = bias_has_value ? f->buffer()->address() : 0;
const auto running_mean_addr = is_training_mode and running_mean_has_value ? g->buffer()->address() : 0;
const auto running_var_addr = is_training_mode and running_var_has_value ? h->buffer()->address() : 0;
std::array writer_runtime_args = {
b.buffer()->address(), // batch mean
d.buffer()->address(), // batch var
weight_addr, // weight
bias_addr, // bias
running_mean_addr, // old running mean
running_var_addr, // old running var
c.buffer()->address(), // output
start_tile_id,
num_tiles_per_core,
@@ -131,28 +138,37 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
using namespace tt;
using namespace tt::tt_metal;

const auto& [a, b, d, e, f, _] = tensor_args;
const auto& [a, b, d, e, f, g, h, _] = tensor_args;

auto program = CreateProgram();

auto* device = a.device();

const bool weight_has_value = e.has_value();
const bool bias_has_value = f.has_value();
const bool running_mean_has_value = g.has_value();
const bool running_var_has_value = h.has_value();
const bool is_training_mode = operation_attributes.training;

auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
auto c_data_format = datatype_to_dataformat_converter(output.get_dtype());
auto d_data_format = datatype_to_dataformat_converter(d.get_dtype());
auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b;
auto g_data_format = is_training_mode and running_mean_has_value ? datatype_to_dataformat_converter(g->get_dtype())
: DataFormat::Float16_b;
auto h_data_format = is_training_mode and running_var_has_value ? datatype_to_dataformat_converter(h->get_dtype())
: DataFormat::Float16_b;

uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
uint32_t c_single_tile_size = tt_metal::detail::TileSize(c_data_format);
uint32_t d_single_tile_size = tt_metal::detail::TileSize(d_data_format);
uint32_t e_single_tile_size = tt_metal::detail::TileSize(e_data_format);
uint32_t f_single_tile_size = tt_metal::detail::TileSize(f_data_format);
uint32_t g_single_tile_size = tt_metal::detail::TileSize(g_data_format);
uint32_t h_single_tile_size = tt_metal::detail::TileSize(h_data_format);

uint32_t num_output_tiles = output.volume() / output.tensor_spec().tile().get_tile_hw();

@@ -199,6 +215,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
d_single_tile_size,
b_num_tiles_per_cb,
d_data_format); // momentum
auto [g_cb, g_cb_handle] = create_cb(
tt::CBIndex::c_25,
program,
all_device_cores,
g_single_tile_size,
b_num_tiles_per_cb,
g_data_format); // old running mean
auto [h_cb, h_cb_handle] = create_cb(
tt::CBIndex::c_26,
program,
all_device_cores,
h_single_tile_size,
b_num_tiles_per_cb,
h_data_format); // old running var

// Temporary buffers to store intermediate results
auto [den_cb, den_cb_handle] = create_cb(
@@ -224,6 +254,10 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
const auto g_is_dram =
is_training_mode and running_mean_has_value and g->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
const auto h_is_dram =
is_training_mode and running_var_has_value and h->buffer()->buffer_type() == tt_metal::BufferType::DRAM;

// READER KERNEL
auto reader_kernel_id = tt_metal::CreateKernel(
@@ -237,23 +271,30 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp",
all_device_cores,
-    tt_metal::WriterDataMovementConfig(
-        {b_is_dram,
-         c_is_dram,
-         d_is_dram,
-         e_is_dram,
-         f_is_dram,
-         static_cast<uint32_t>(weight_has_value),
-         static_cast<uint32_t>(bias_has_value),
-         static_cast<uint32_t>(operation_attributes.training)}));
+    tt_metal::WriterDataMovementConfig({
+        b_is_dram,
+        c_is_dram,
+        d_is_dram,
+        e_is_dram,
+        f_is_dram,
+        static_cast<uint32_t>(weight_has_value),
+        static_cast<uint32_t>(bias_has_value),
+        static_cast<uint32_t>(operation_attributes.training),
+        g_is_dram,
+        h_is_dram,
+        static_cast<uint32_t>(running_mean_has_value),
+        static_cast<uint32_t>(running_var_has_value),
+    }));

// COMPUTE KERNEL
bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
c_data_format == tt::DataFormat::Float32;
std::vector<uint32_t> compute_kernel_args = {
static_cast<uint32_t>(weight_has_value),
static_cast<uint32_t>(bias_has_value),
static_cast<uint32_t>(operation_attributes.training)};
static_cast<uint32_t>(operation_attributes.training),
static_cast<uint32_t>(running_mean_has_value),
static_cast<uint32_t>(running_var_has_value)};
auto compute_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp",
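The compute kernel consumes these flags positionally, so compute_kernel_args must stay in lock-step with the get_compile_time_arg_val indices in batch_norm_kernel.cpp (0: weight, 1: bias, 2: training, 3: running_mean, 4: running_var). A toy sketch of that positional contract (illustrative Python with example values, not kernel code):

```python
# Example host-side flags.
weight_has_value, bias_has_value = True, False
is_training_mode = True
running_mean_has_value, running_var_has_value = True, True

# Order here must match get_compile_time_arg_val(0..4) in the kernel.
compute_kernel_args = [
    int(weight_has_value),        # get_compile_time_arg_val(0)
    int(bias_has_value),          # get_compile_time_arg_val(1)
    int(is_training_mode),        # get_compile_time_arg_val(2)
    int(running_mean_has_value),  # get_compile_time_arg_val(3)
    int(running_var_has_value),   # get_compile_time_arg_val(4)
]
assert compute_kernel_args[2] == 1  # training mode enabled in this example
```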
ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
@@ -40,6 +40,8 @@ void MAIN {
constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
constexpr uint32_t is_training_mode = get_compile_time_arg_val(2) == 1;
constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(3) == 1;
constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(4) == 1;

if (num_tiles == 0) {
return;
@@ -56,6 +58,8 @@ void MAIN {
constexpr auto cb_weight = tt::CBIndex::c_16; // weight tensor
constexpr auto cb_tmp_1 = tt::CBIndex::c_17; // (input - batch_mean)/(sqrt(batch_var + eps))
constexpr auto cb_bias = tt::CBIndex::c_18; // bias tensor
constexpr auto cb_old_running_mean = tt::CBIndex::c_25; // old running mean tensor
constexpr auto cb_old_running_var = tt::CBIndex::c_26; // old running var tensor

auto cb_bcast = cb_batch_mean;
auto cb_other = cb_input;
@@ -119,6 +123,11 @@

if constexpr (is_training_mode) {
// update running stats here
if constexpr (old_running_mean_has_value) {
}

if constexpr (old_running_var_has_value) {
}
}

if constexpr (weight_has_value) { // result = result * weight
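The running-stats update itself is still a stub in this commit (the two empty if constexpr branches above). Following PyTorch's convention — an assumption here, since the stub doesn't confirm it — the update is an exponential moving average, with the twist that running_var accumulates the unbiased batch variance even though normalization uses the biased one. A torch sketch of what those branches would compute:

```python
import torch

def update_running_stats(running_mean, running_var,
                         batch_mean, batch_var_biased, momentum, n):
    # n = elements reduced per channel (N * H * W); unbias the batch variance.
    unbiased_var = batch_var_biased * n / (n - 1)
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * unbiased_var
    return new_mean, new_var

# Check against torch's in-place update in training mode.
x = torch.randn(2, 3, 8, 8)
ref_mean, ref_var = torch.zeros(3), torch.ones(3)
torch.nn.functional.batch_norm(x, ref_mean, ref_var, training=True, momentum=0.1)

bm = x.mean(dim=(0, 2, 3))
bv = x.var(dim=(0, 2, 3), unbiased=False)
nm, nv = update_running_stats(torch.zeros(3), torch.ones(3), bm, bv, 0.1, 2 * 8 * 8)
assert torch.allclose(nm, ref_mean, atol=1e-6)
assert torch.allclose(nv, ref_var, atol=1e-5)
```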