#12253: Update kernel to read, write running stats and test file

VirdhatchaniKN committed Jan 29, 2025
1 parent f17dfea commit 68e68a5
Showing 8 changed files with 288 additions and 37 deletions.
39 changes: 32 additions & 7 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -24,31 +24,50 @@
torch.Size([3, 2, 64, 120]),
],
)
@pytest.mark.parametrize("training", [False])
@pytest.mark.parametrize(
"training, check_mean, check_var",
[
# (True, True, True),
# (True, True, False),
# (True, False, True),
(True, False, False),
(False, False, False), # xfail case
(False, True, False), # xfail case
(False, False, True), # xfail case
(False, True, True),
],
)
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
def test_batch_norm(input_shapes, training, weight, bias, eps, device):
@pytest.mark.parametrize("momentum", [0.1, 0.0, 2.3])
def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
mean_data, mean_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
)
var_data, var_tensor = (
data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
)
var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)

if (not check_mean) or (not check_var):
pytest.xfail("running_mean and running_var must be defined in evaluation mode")

tt_output_tensor_on_device = ttnn.batch_norm(
input_tensor,
running_mean=mean_tensor,
running_var=var_tensor,
training=training,
eps=eps,
momentum=momentum,
weight=weight_tensor,
bias=bias_tensor,
)
tt_output = ttnn.to_torch(tt_output_tensor_on_device)

# tt_updated_mean = ttnn.to_torch(mean_tensor)
# tt_updated_var = ttnn.to_torch(var_tensor)

# ttnn.set_printoptions(profile="full")
# print("TT result : ", tt_output, tt_output.shape)
# torch.set_printoptions(precision=5, sci_mode=False)
@@ -60,9 +79,15 @@ def test_batch_norm(input_shapes, training, weight, bias, eps, device):
bias=bias_data,
training=training,
eps=eps,
momentum=momentum,
)
# print("Torch result : ",torch_result)
comp_pass = compare_results_batch_norm([tt_output], [torch_result])
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) # Check BN Result
# if training :
# channels = input_shapes[1]
# comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data.view(1, channels, 1, 1)]) # Check Updated running mean
# comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data.view(1, channels, 1, 1)]) # Check Updated running var
# comp_pass = comp_pass and comp_pass_1 and comp_pass_2
assert comp_pass


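For reference, the commented-out checks above compare the updated running statistics against PyTorch's own in-place update rule. A minimal sketch of that reference behaviour, with illustrative shapes (not taken from the test utilities):

import torch

x = torch.randn(2, 3, 64, 64)          # (N, C, H, W)
running_mean = torch.zeros(3)           # per-channel stats are 1-D in torch
running_var = torch.ones(3)

# In training mode torch.nn.functional.batch_norm updates the stats in place:
#   running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
#   running_var  <- (1 - momentum) * running_var  + momentum * unbiased batch var
torch.nn.functional.batch_norm(x, running_mean, running_var, training=True, momentum=0.1, eps=1e-05)

# The updated stats can then be viewed as (1, C, 1, 1) for comparison with
# the ttnn running_mean / running_var tensors.
expected_mean = running_mean.view(1, 3, 1, 1)
expected_var = running_var.view(1, 3, 1, 1)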
50 changes: 48 additions & 2 deletions ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -5,11 +5,27 @@
#include "batch_norm.hpp"

#include "device/batch_norm_device_operation.hpp"
#include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::normalization {

inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
auto output_memory_config = memory_config.value_or(input_tensor.memory_config());
auto batch_mean = input_tensor;
ttnn::SmallVector<int64_t> dims = {0, 2, 3};
std::sort(dims.begin(), dims.end());
for (uint32_t i = dims.size() - 1; i > 0; i--) {
auto temp_output = ttnn::prim::moreh_mean(
batch_mean, dims[i], true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
batch_mean = temp_output;
}
return ttnn::prim::moreh_mean(
batch_mean, dims.front(), true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
}

Tensor BatchNorm::invoke(
const Tensor& input,
std::optional<Tensor> running_mean,
@@ -21,10 +37,40 @@ Tensor BatchNorm::invoke(
std::optional<Tensor> bias,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
if (training) {
Tensor batch_mean = mean_NHW(input, memory_config);
Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
Tensor batch_var =
ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
return ttnn::prim::batch_norm(
input,
batch_mean,
batch_var,
eps,
momentum,
training,
weight,
bias,
running_mean,
running_var,
output,
memory_config);
}
TT_FATAL(
(running_mean.has_value() && running_var.has_value() && (!training)),
(running_mean.has_value() && running_var.has_value()),
"running_mean and running_var must be defined in evaluation mode");
return ttnn::prim::batch_norm(
input, running_mean.value(), running_var.value(), eps, momentum, training, weight, bias, output, memory_config);
input,
running_mean.value(),
running_var.value(),
eps,
momentum,
training,
weight,
bias,
std::nullopt,
std::nullopt,
output,
memory_config);
}
} // namespace ttnn::operations::normalization
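In the training branch above, mean_NHW reduces over the N, H and W dimensions with repeated moreh_mean calls, and the batch variance is formed from the identity Var[x] = E[x^2] - (E[x])^2. A short PyTorch sketch of the equivalent per-channel computation (shapes illustrative):

import torch

x = torch.randn(2, 3, 64, 64)                       # (N, C, H, W)
batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)     # (1, C, 1, 1)
mean_sq = (x * x).mean(dim=(0, 2, 3), keepdim=True)  # E[x^2]
batch_var = mean_sq - batch_mean * batch_mean        # biased variance, as in the kernel path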
@@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
module,
ttnn::batch_norm,
R"doc(
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only.
Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved.
Args:
@@ -24,8 +24,8 @@ void bind_batch_norm_operation(pybind11::module& module) {
Keyword args:
eps (float, optional): Epsilon value. Defaults to `1e-05`.
momentum (float, optional): Momentum value. Defaults to `0.1`.
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running mean value is stored in-place based on the inputs provided. Defaults to `None`.
running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode. When in training mode, this tensor is optional and the updated running variance value is stored in-place based on the inputs provided. Defaults to `None`.
weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
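To illustrate the two modes described in the docstring above, a hedged usage sketch follows; device setup and ttnn.from_torch conversion are assumptions, and the shapes follow the `[1, C, 1, 1]` convention:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

def to_tt(t):
    # Inputs must be tilized and interleaved.
    return ttnn.from_torch(t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

N, C, H, W = 2, 3, 64, 64
input_tensor = to_tt(torch.randn(N, C, H, W))
mean_tensor = to_tt(torch.zeros(1, C, 1, 1))
var_tensor = to_tt(torch.ones(1, C, 1, 1))

# Inference mode: running_mean and running_var are required.
out = ttnn.batch_norm(
    input_tensor,
    running_mean=mean_tensor,
    running_var=var_tensor,
    training=False,
    eps=1e-05,
    momentum=0.1,
)

# Training mode: batch statistics are computed on device; running stats are
# optional and, when provided, are updated in place.
out_train = ttnn.batch_norm(input_tensor, training=True, momentum=0.1)

ttnn.close_device(device)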
@@ -10,14 +10,16 @@
namespace ttnn::operations::normalization {
void BatchNormOperation::validate_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;

check_tensor(input, "batch_norm", "input");
check_tensor(batch_mean, "batch_norm", "batch_mean");
check_tensor(batch_var, "batch_norm", "batch_var");
check_tensor(weight, "batch_norm", "weight");
check_tensor(bias, "batch_norm", "bias");
check_tensor(output, "batch_norm", "output");
check_tensor(running_mean, "batch_norm", "running_mean");
check_tensor(running_var, "batch_norm", "running_var");

// input (N, C, H, W)
auto C = input.get_logical_shape()[1];
@@ -45,6 +47,26 @@ void BatchNormOperation::validate_tensors(
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
}

// running_mean (1, C, 1, 1)
if (running_mean.has_value()) {
TT_FATAL(
running_mean.value().get_logical_shape()[1] == C,
"running_mean_shape[1] must be the same as input's channel size.");
TT_FATAL(
running_mean.value().get_logical_shape()[1] == C,
"running_mean_shape[1] must be the same as input's channel size.");
}

// running_var (1, C, 1, 1)
if (running_var.has_value()) {
TT_FATAL(
running_var.value().get_logical_shape()[1] == C,
"running_var_shape[1] must be the same as input's channel size.");
TT_FATAL(
running_var.value().get_logical_shape()[1] == C,
"running_var_shape[1] must be the same as input's channel size.");
}
}

BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
@@ -54,7 +76,8 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory

void BatchNormOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;

const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;

TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be tilized");
TT_FATAL(
@@ -87,6 +110,20 @@ void BatchNormOperation::validate_on_program_cache_miss(
"bias tensor must be interleaved");
}

if (running_mean.has_value()) {
TT_FATAL(running_mean.value().get_layout() == Layout::TILE, "running_mean tensor must be tilized");
TT_FATAL(
running_mean.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"running_mean tensor must be interleaved");
}

if (running_var.has_value()) {
TT_FATAL(running_var.value().get_layout() == Layout::TILE, "running_var tensor must be tilized");
TT_FATAL(
running_var.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"running_var tensor must be interleaved");
}

validate_tensors(operation_attributes, tensor_args);
};

@@ -127,10 +164,20 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
tensor_args_t tensor_args{
input,
batch_mean,
batch_var,
std::move(weight),
std::move(bias),
std::move(running_mean),
std::move(running_var),
std::move(output)};
return {operation_attributes, tensor_args};
}
} // namespace ttnn::operations::normalization
@@ -27,6 +27,8 @@ struct BatchNormOperation {
const Tensor& batch_var;
std::optional<Tensor> weight;
std::optional<Tensor> bias;
std::optional<Tensor> running_mean;
std::optional<Tensor> running_var;
std::optional<Tensor> output;
};

@@ -72,6 +74,8 @@ struct BatchNormOperation {
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config);
};