Skip to content

Commit

Permalink
#12253: Remove update of running stats
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jan 22, 2025
1 parent e5d2182 commit 29b1fed
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 329 deletions.
42 changes: 21 additions & 21 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,27 @@ def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias,
momentum=momentum,
)
comp_pass = compare_results_batch_norm([tt_output], [torch_result]) # Check BN Result
if training:
channels = input_shapes[1]
if check_mean:
comp_pass_1 = compare_results_batch_norm(
[tt_updated_mean], [mean_data.view(1, channels, 1, 1)]
) # Check Updated running mean
else:
if tt_updated_mean is None:
comp_pass_1 = True
else:
comp_pass_1 = False
if check_var:
comp_pass_2 = compare_results_batch_norm(
[tt_updated_var], [var_data.view(1, channels, 1, 1)]
) # Check Updated running var
else:
if tt_updated_var is None:
comp_pass_2 = True
else:
comp_pass_2 = False
comp_pass = comp_pass and comp_pass_1 and comp_pass_2
# if training:
# channels = input_shapes[1]
# if check_mean:
# comp_pass_1 = compare_results_batch_norm(
# [tt_updated_mean], [mean_data.view(1, channels, 1, 1)]
# ) # Check Updated running mean
# else:
# if tt_updated_mean is None:
# comp_pass_1 = True
# else:
# comp_pass_1 = False
# if check_var:
# comp_pass_2 = compare_results_batch_norm(
# [tt_updated_var], [var_data.view(1, channels, 1, 1)]
# ) # Check Updated running var
# else:
# if tt_updated_var is None:
# comp_pass_2 = True
# else:
# comp_pass_2 = False
# comp_pass = comp_pass and comp_pass_1 and comp_pass_2

assert comp_pass

Expand Down
29 changes: 3 additions & 26 deletions ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "device/batch_norm_device_operation.hpp"
#include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp"
#include "ttnn/operations/eltwise/unary/device/unary_composite_op.hpp"

using namespace tt::tt_metal;

Expand Down Expand Up @@ -42,35 +42,12 @@ Tensor BatchNorm::invoke(
Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
Tensor batch_var =
ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
return ttnn::prim::batch_norm(
input,
batch_mean,
batch_var,
eps,
momentum,
training,
weight,
bias,
running_mean,
running_var,
output,
memory_config);
return ttnn::prim::batch_norm(input, batch_mean, batch_var, eps, weight, bias, output, memory_config);
}
TT_FATAL(
(running_mean.has_value() && running_var.has_value()),
"running_mean and running_var must be defined in evaluation mode");
return ttnn::prim::batch_norm(
input,
running_mean.value(),
running_var.value(),
eps,
momentum,
training,
weight,
bias,
std::nullopt,
std::nullopt,
output,
memory_config);
input, running_mean.value(), running_var.value(), eps, weight, bias, output, memory_config);
}
} // namespace ttnn::operations::normalization
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
namespace ttnn::operations::normalization {
void BatchNormOperation::validate_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;

check_tensor(input, "batch_norm", "input");
check_tensor(batch_mean, "batch_norm", "batch_mean");
check_tensor(batch_var, "batch_norm", "batch_var");
check_tensor(weight, "batch_norm", "weight");
check_tensor(bias, "batch_norm", "bias");
check_tensor(output, "batch_norm", "output");
check_tensor(running_mean, "batch_norm", "running_mean");
check_tensor(running_var, "batch_norm", "running_var");

// input (N, C, H, W)
auto C = input.get_logical_shape()[1];
Expand Down Expand Up @@ -47,26 +45,6 @@ void BatchNormOperation::validate_tensors(
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
}

// running_mean (1, C, 1, 1)
if (running_mean.has_value()) {
TT_FATAL(
running_mean.value().get_logical_shape()[1] == C,
"running_mean_shape[1] must be the same as input's channel size.");
TT_FATAL(
running_mean.value().get_logical_shape()[1] == C,
"running_mean_shape[1] must be the same as input's channel size.");
}

// running_var (1, C, 1, 1)
if (running_var.has_value()) {
TT_FATAL(
running_var.value().get_logical_shape()[1] == C,
"running_var_shape[1] must be the same as input's channel size.");
TT_FATAL(
running_var.value().get_logical_shape()[1] == C,
"running_var_shape[1] must be the same as input's channel size.");
}
}

BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
Expand All @@ -76,8 +54,7 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory

void BatchNormOperation::validate_on_program_cache_miss(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {

const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;

TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be must be tilized");
TT_FATAL(
Expand Down Expand Up @@ -110,20 +87,6 @@ void BatchNormOperation::validate_on_program_cache_miss(
"bias tensor must be interleaved");
}

if (running_mean.has_value()) {
TT_FATAL(running_mean.value().get_layout() == Layout::TILE, "running_mean tensor must be tilized");
TT_FATAL(
running_mean.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"running_mean tensor must be interleaved");
}

if (running_var.has_value()) {
TT_FATAL(running_var.value().get_layout() == Layout::TILE, "running_var tensor must be tilized");
TT_FATAL(
running_var.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
"running_var tensor must be interleaved");
}

validate_tensors(operation_attributes, tensor_args);
};

Expand Down Expand Up @@ -160,24 +123,12 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
const Tensor& batch_mean,
const Tensor& batch_var,
const float eps,
const float momentum,
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config) {
operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
tensor_args_t tensor_args{
input,
batch_mean,
batch_var,
std::move(weight),
std::move(bias),
std::move(running_mean),
std::move(running_var),
std::move(output)};
operation_attributes_t operation_attributes{eps, memory_config.value_or(input.memory_config())};
tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
return {operation_attributes, tensor_args};
}
} // namespace ttnn::operations::normalization
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ namespace ttnn::operations::normalization {
struct BatchNormOperation {
struct operation_attributes_t {
const float eps;
const float momentum;
const bool training;
const MemoryConfig memory_config;

DataType input_dtype;
Expand All @@ -27,8 +25,6 @@ struct BatchNormOperation {
const Tensor& batch_var;
std::optional<Tensor> weight;
std::optional<Tensor> bias;
std::optional<Tensor> running_mean;
std::optional<Tensor> running_var;
std::optional<Tensor> output;
};

Expand Down Expand Up @@ -70,12 +66,8 @@ struct BatchNormOperation {
const Tensor& batch_mean,
const Tensor& batch_var,
const float eps,
const float momentum,
const bool training,
std::optional<Tensor> weight,
std::optional<Tensor> bias,
std::optional<Tensor> running_mean,
std::optional<Tensor> running_var,
std::optional<Tensor> output,
const std::optional<MemoryConfig>& memory_config);
};
Expand Down
Loading

0 comments on commit 29b1fed

Please sign in to comment.