diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py
index 2e73a8ae0562..459ba02ea17e 100644
--- a/tests/ttnn/unit_tests/operations/test_batch_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py
@@ -24,31 +24,50 @@
         torch.Size([3, 2, 64, 120]),
     ],
 )
-@pytest.mark.parametrize("training", [False])
+@pytest.mark.parametrize(
+    "training, check_mean, check_var",
+    [
+        # (True, True, True),
+        # (True, True, False),
+        # (True, False, True),
+        (True, False, False),
+        (False, False, False),  # xfail case
+        (False, True, False),  # xfail case
+        (False, False, True),  # xfail case
+        (False, True, True),
+    ],
+)
 @pytest.mark.parametrize("weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05])
-def test_batch_norm(input_shapes, training, weight, bias, eps, device):
+@pytest.mark.parametrize("momentum", [0.1, 0.0, 2.3])
+def test_batch_norm(input_shapes, training, check_mean, check_var, weight, bias, eps, momentum, device):
     in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True)
     mean_data, mean_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None)
-    )
-    var_data, var_tensor = (
-        data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None)
+        data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (check_mean) else (None, None)
     )
+    var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (check_var) else (None, None)
     weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None)
     bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None)
 
+    if (not training) and ((not check_mean) or (not check_var)):
+        pytest.xfail("running_mean and running_var must be defined in evaluation mode")
+
     tt_output_tensor_on_device = ttnn.batch_norm(
         input_tensor,
         running_mean=mean_tensor,
         running_var=var_tensor,
         training=training,
         eps=eps,
+        momentum=momentum,
         weight=weight_tensor,
         bias=bias_tensor,
     )
     tt_output = ttnn.to_torch(tt_output_tensor_on_device)
+
+    # tt_updated_mean = ttnn.to_torch(mean_tensor)
+    # tt_updated_var = ttnn.to_torch(var_tensor)
+
     # ttnn.set_printoptions(profile="full")
     # print("TT result : ", tt_output, tt_output.shape)
     # torch.set_printoptions(precision=5, sci_mode=False)
@@ -60,9 +79,15 @@ def test_batch_norm(input_shapes, training, weight, bias, eps, device):
         bias=bias_data,
         training=training,
         eps=eps,
+        momentum=momentum,
     )
     # print("Torch result : ",torch_result)
-    comp_pass = compare_results_batch_norm([tt_output], [torch_result])
+    comp_pass = compare_results_batch_norm([tt_output], [torch_result])  # Check BN Result
+    # if training :
+    #     channels = input_shapes[1]
+    #     comp_pass_1 = compare_results_batch_norm([tt_updated_mean], [mean_data.view(1, channels, 1, 1)]) # Check Updated running mean
+    #     comp_pass_2 = compare_results_batch_norm([tt_updated_var], [var_data.view(1, channels, 1, 1)])  # Check Updated running var
+    #     comp_pass = comp_pass and comp_pass_1 and comp_pass_2
     assert comp_pass
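
Note on the commented-out running-stat checks above: in training mode, `torch.nn.functional.batch_norm` blends the batch statistics into the running buffers, and `running_var` is updated with the unbiased estimate. A minimal host-side reference of the expected update (a sketch; `expected_running_stats` is an illustrative helper, not part of the test):

```python
import torch

def expected_running_stats(x, running_mean, running_var, momentum):
    # torch semantics in training mode: stats reduce over N, H, W and are
    # blended into the running buffers; running_var uses the unbiased estimate.
    batch_mean = x.mean(dim=(0, 2, 3))
    batch_var = x.var(dim=(0, 2, 3), unbiased=True)
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var
```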
 
 
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
index 5bcc1ec44861..dba53e2c5ac6 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm.cpp
@@ -5,11 +5,27 @@
 #include "batch_norm.hpp"
 
 #include "device/batch_norm_device_operation.hpp"
+#include "ttnn/operations/moreh/moreh_mean/device/moreh_mean_device_operation.hpp"
+#include "ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp"
 
 using namespace tt::tt_metal;
 
 namespace ttnn::operations::normalization {
 
+inline Tensor mean_NHW(const Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config) {
+    auto output_memory_config = memory_config.value_or(input_tensor.memory_config());
+    auto batch_mean = input_tensor;
+    ttnn::SmallVector<int64_t> dims = {0, 2, 3};
+    std::sort(dims.begin(), dims.end());
+    for (uint32_t i = dims.size() - 1; i > 0; i--) {
+        auto temp_output = ttnn::prim::moreh_mean(
+            batch_mean, dims[i], true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
+        batch_mean = temp_output;
+    }
+    return ttnn::prim::moreh_mean(
+        batch_mean, dims.front(), true, std::nullopt, std::nullopt, output_memory_config, std::nullopt);
+}
+
 Tensor BatchNorm::invoke(
     const Tensor& input,
     std::optional<Tensor> running_mean,
@@ -21,10 +37,40 @@ Tensor BatchNorm::invoke(
     std::optional<Tensor> bias,
     std::optional<Tensor> output,
     const std::optional<MemoryConfig>& memory_config) {
+    if (training) {
+        Tensor batch_mean = mean_NHW(input, memory_config);
+        Tensor mean_sq = mean_NHW(ttnn::square(input, memory_config), memory_config);
+        Tensor batch_var =
+            ttnn::subtract(mean_sq, ttnn::square(batch_mean, memory_config), std::nullopt, memory_config);
+        return ttnn::prim::batch_norm(
+            input,
+            batch_mean,
+            batch_var,
+            eps,
+            momentum,
+            training,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            output,
+            memory_config);
+    }
     TT_FATAL(
-        (running_mean.has_value() && running_var.has_value() && (!training)),
+        (running_mean.has_value() && running_var.has_value()),
         "running_mean and running_var must be defined in evaluation mode");
     return ttnn::prim::batch_norm(
-        input, running_mean.value(), running_var.value(), eps, momentum, training, weight, bias, output, memory_config);
+        input,
+        running_mean.value(),
+        running_var.value(),
+        eps,
+        momentum,
+        training,
+        weight,
+        bias,
+        std::nullopt,
+        std::nullopt,
+        output,
+        memory_config);
 }
 }  // namespace ttnn::operations::normalization
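
The training path above derives the batch variance from moments, Var[x] = E[x²] − (E[x])², using two `mean_NHW` reductions instead of a dedicated variance op. A quick host-side check of that identity (a torch sketch, not the device path):

```python
import torch

x = torch.randn(3, 2, 64, 120)
mean = x.mean(dim=(0, 2, 3), keepdim=True)           # mean_NHW(input)
mean_sq = (x * x).mean(dim=(0, 2, 3), keepdim=True)  # mean_NHW(square(input))
var = mean_sq - mean * mean                          # batch_var (biased estimate)
assert torch.allclose(var, x.var(dim=(0, 2, 3), unbiased=False, keepdim=True), atol=1e-5)
```

One caveat worth a source comment: the moments formulation cancels catastrophically when the mean is large relative to the variance, which may matter at bfloat16 precision.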
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
index c28165ef8c4e..92973c9da986 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp
@@ -14,7 +14,7 @@ void bind_batch_norm_operation(pybind11::module& module) {
         module,
         ttnn::batch_norm,
         R"doc(
-            Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only.
+            Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be tilized and interleaved.
 
 
         Args:
@@ -24,8 +24,8 @@ void bind_batch_norm_operation(pybind11::module& module) {
         Keyword args:
             eps (float, optional): Epsilon value. Defaults to `1e-05`.
             momentum (float, optional): Momentum value. Defaults to `0.1`.
-            running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
-            running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`.
+            running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode. In training mode this tensor is optional; when provided, the updated running mean is stored back into it in place. Defaults to `None`.
+            running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode. In training mode this tensor is optional; when provided, the updated running variance is stored back into it in place. Defaults to `None`.
             weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`.
             bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`.
             training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode).
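
Given the updated docstring, a minimal training-mode call looks like the sketch below. The tensor names are placeholders; all inputs must be tilized and interleaved, with stats, weight and bias shaped `[1, C, 1, 1]` as documented:

```python
import ttnn

output = ttnn.batch_norm(
    input_tensor,               # [N, C, H, W], TILE layout, interleaved
    running_mean=mean_tensor,   # [1, C, 1, 1]; updated in place when training=True
    running_var=var_tensor,     # [1, C, 1, 1]; updated in place when training=True
    weight=weight_tensor,
    bias=bias_tensor,
    training=True,
    eps=1e-05,
    momentum=0.1,
)
```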
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
index 87caa2213397..9673493a5e52 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp
@@ -10,7 +10,7 @@
 namespace ttnn::operations::normalization {
 void BatchNormOperation::validate_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
+    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
 
     check_tensor(input, "batch_norm", "input");
     check_tensor(batch_mean, "batch_norm", "batch_mean");
@@ -18,6 +18,8 @@ void BatchNormOperation::validate_tensors(
     check_tensor(weight, "batch_norm", "weight");
     check_tensor(bias, "batch_norm", "bias");
     check_tensor(output, "batch_norm", "output");
+    check_tensor(running_mean, "batch_norm", "running_mean");
+    check_tensor(running_var, "batch_norm", "running_var");
 
     // input (N, C, H, W)
     auto C = input.get_logical_shape()[1];
@@ -45,6 +47,26 @@ void BatchNormOperation::validate_tensors(
         TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
         TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size.");
     }
+
+    // running_mean (1, C, 1, 1)
+    if (running_mean.has_value()) {
+        TT_FATAL(
+            running_mean.value().get_logical_shape()[1] == C,
+            "running_mean_shape[1] must be the same as input's channel size.");
+    }
+
+    // running_var (1, C, 1, 1)
+    if (running_var.has_value()) {
+        TT_FATAL(
+            running_var.value().get_logical_shape()[1] == C,
+            "running_var_shape[1] must be the same as input's channel size.");
+    }
 }
 
 BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory(
@@ -54,7 +76,8 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory
 
 void BatchNormOperation::validate_on_program_cache_miss(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args;
+    const auto& [input, batch_mean, batch_var, weight, bias, running_mean, running_var, output] = tensor_args;
 
     TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be must be tilized");
     TT_FATAL(
@@ -87,6 +110,20 @@ void BatchNormOperation::validate_on_program_cache_miss(
             "bias tensor must be interleaved");
     }
 
+    if (running_mean.has_value()) {
+        TT_FATAL(running_mean.value().get_layout() == Layout::TILE, "running_mean tensor must be tilized");
+        TT_FATAL(
+            running_mean.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
+            "running_mean tensor must be interleaved");
+    }
+
+    if (running_var.has_value()) {
+        TT_FATAL(running_var.value().get_layout() == Layout::TILE, "running_var tensor must be tilized");
+        TT_FATAL(
+            running_var.value().memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED,
+            "running_var tensor must be interleaved");
+    }
+
     validate_tensors(operation_attributes, tensor_args);
 };
 
@@ -127,10 +164,20 @@ std::tuple<BatchNormOperation::operation_attributes_t, BatchNormOperation::tenso
     const bool training,
     std::optional<Tensor> weight,
     std::optional<Tensor> bias,
+    std::optional<Tensor> running_mean,
+    std::optional<Tensor> running_var,
     std::optional<Tensor> output,
     const std::optional<MemoryConfig>& memory_config) {
     operation_attributes_t operation_attributes{eps, momentum, training, memory_config.value_or(input.memory_config())};
-    tensor_args_t tensor_args{input, batch_mean, batch_var, std::move(weight), std::move(bias), std::move(output)};
+    tensor_args_t tensor_args{
+        input,
+        batch_mean,
+        batch_var,
+        std::move(weight),
+        std::move(bias),
+        std::move(running_mean),
+        std::move(running_var),
+        std::move(output)};
     return {operation_attributes, tensor_args};
 }
 }  // namespace ttnn::operations::normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
index 985634f6dfdb..d9d848c7564a 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.hpp
@@ -27,6 +27,8 @@ struct BatchNormOperation {
         const Tensor& batch_var;
         std::optional<Tensor> weight;
         std::optional<Tensor> bias;
+        std::optional<Tensor> running_mean;
+        std::optional<Tensor> running_var;
         std::optional<Tensor> output;
     };
 
@@ -72,6 +74,8 @@ struct BatchNormOperation {
         const bool training,
         std::optional<Tensor> weight,
         std::optional<Tensor> bias,
+        std::optional<Tensor> running_mean,
+        std::optional<Tensor> running_var,
         std::optional<Tensor> output,
         const std::optional<MemoryConfig>& memory_config);
 };
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
index 9391a5c99447..1fcc09197362 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp
@@ -29,12 +29,15 @@ void set_or_update_runtime_arguments(
     const BatchNormOperation::tensor_args_t& tensor_args,
     BatchNormOperation::tensor_return_value_t& c,
     F handle_args) {
-    const auto& [a, b, d, e, f, _] = tensor_args;
+    const auto& [a, b, d, e, f, g, h, _] = tensor_args;
     const auto eps = operation_attributes.eps;
     const auto momentum = operation_attributes.momentum;
 
     const bool weight_has_value = e.has_value();
     const bool bias_has_value = f.has_value();
+    const bool running_mean_has_value = g.has_value();
+    const bool running_var_has_value = h.has_value();
+    const bool is_training_mode = operation_attributes.training;
 
     const auto ashape = a.padded_shape();
     const auto bshape = b.padded_shape();
@@ -65,7 +68,7 @@ void set_or_update_runtime_arguments(
             num_tiles_per_core = num_tiles_per_core_group_2;
         } else {
             handle_args(program, reader_kernel_id, core, std::array<uint32_t, 12>{0});
-            handle_args(program, writer_kernel_id, core, std::array<uint32_t, 14>{0});
+            handle_args(program, writer_kernel_id, core, std::array<uint32_t, 16>{0});
             handle_args(program, compute_kernel_id, core, std::array<uint32_t, 3>{0});
             continue;
         }
@@ -93,11 +96,15 @@ void set_or_update_runtime_arguments(
 
         const auto weight_addr = weight_has_value ? e->buffer()->address() : 0;
         const auto bias_addr = bias_has_value ? f->buffer()->address() : 0;
+        const auto running_mean_addr = is_training_mode and running_mean_has_value ? g->buffer()->address() : 0;
+        const auto running_var_addr = is_training_mode and running_var_has_value ? h->buffer()->address() : 0;
         std::array writer_runtime_args = {
             b.buffer()->address(),  //  batch mean
             d.buffer()->address(),  //  batch var
             weight_addr,            // weight
             bias_addr,              // bias
+            running_mean_addr,      // old running mean
+            running_var_addr,       // old running var
             c.buffer()->address(),  // output
             start_tile_id,
             num_tiles_per_core,
@@ -131,7 +138,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     using namespace tt;
     using namespace tt::tt_metal;
 
-    const auto& [a, b, d, e, f, _] = tensor_args;
+    const auto& [a, b, d, e, f, g, h, _] = tensor_args;
 
     auto program = CreateProgram();
 
@@ -139,6 +146,9 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
 
     const bool weight_has_value = e.has_value();
     const bool bias_has_value = f.has_value();
+    const bool running_mean_has_value = g.has_value();
+    const bool running_var_has_value = h.has_value();
+    const bool is_training_mode = operation_attributes.training;
 
     auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
     auto b_data_format = datatype_to_dataformat_converter(b.get_dtype());
@@ -146,6 +156,10 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     auto d_data_format = datatype_to_dataformat_converter(d.get_dtype());
     auto e_data_format = weight_has_value ? datatype_to_dataformat_converter(e->get_dtype()) : DataFormat::Float16_b;
     auto f_data_format = bias_has_value ? datatype_to_dataformat_converter(f->get_dtype()) : DataFormat::Float16_b;
+    auto g_data_format = is_training_mode and running_mean_has_value ? datatype_to_dataformat_converter(g->get_dtype())
+                                                                     : DataFormat::Float16_b;
+    auto h_data_format = is_training_mode and running_var_has_value ? datatype_to_dataformat_converter(h->get_dtype())
+                                                                    : DataFormat::Float16_b;
 
     uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
     uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
@@ -153,6 +167,8 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     uint32_t d_single_tile_size = tt_metal::detail::TileSize(d_data_format);
     uint32_t e_single_tile_size = tt_metal::detail::TileSize(e_data_format);
     uint32_t f_single_tile_size = tt_metal::detail::TileSize(f_data_format);
+    uint32_t g_single_tile_size = tt_metal::detail::TileSize(g_data_format);
+    uint32_t h_single_tile_size = tt_metal::detail::TileSize(h_data_format);
 
     uint32_t num_output_tiles = output.volume() / output.tensor_spec().tile().get_tile_hw();
 
@@ -199,6 +215,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         d_single_tile_size,
         b_num_tiles_per_cb,
         d_data_format);  // momentum
+    auto [g_cb, g_cb_handle] = create_cb(
+        tt::CBIndex::c_25,
+        program,
+        all_device_cores,
+        g_single_tile_size,
+        b_num_tiles_per_cb,
+        g_data_format);  // old running mean
+    auto [h_cb, h_cb_handle] = create_cb(
+        tt::CBIndex::c_26,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // old running var
 
     // Temporary buffers to store intermediate results
     auto [den_cb, den_cb_handle] = create_cb(
@@ -217,6 +247,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         a_data_format);  // to store input - batch_mean
     auto [temp_1_cb, temp_1_cb_handle] =
         create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
+    auto [updated_m_cb, updated_m_cb_handle] = create_cb(
+        tt::CBIndex::c_27,
+        program,
+        all_device_cores,
+        g_single_tile_size,
+        b_num_tiles_per_cb,
+        g_data_format);  // updated running mean
+    auto [updated_v_cb, updated_v_cb_handle] = create_cb(
+        tt::CBIndex::c_28,
+        program,
+        all_device_cores,
+        h_single_tile_size,
+        b_num_tiles_per_cb,
+        h_data_format);  // updated running var
 
     auto a_is_dram = static_cast<uint32_t>(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     auto b_is_dram = static_cast<uint32_t>(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
@@ -224,6 +268,10 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     auto d_is_dram = static_cast<uint32_t>(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM);
     const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
     const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+    const auto g_is_dram =
+        is_training_mode and running_mean_has_value and g->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
+    const auto h_is_dram =
+        is_training_mode and running_var_has_value and h->buffer()->buffer_type() == tt_metal::BufferType::DRAM;
 
     // READER KERNEL
     auto reader_kernel_id = tt_metal::CreateKernel(
@@ -237,15 +285,20 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
         program,
         "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp",
         all_device_cores,
-        tt_metal::WriterDataMovementConfig(
-            {b_is_dram,
-             c_is_dram,
-             d_is_dram,
-             e_is_dram,
-             f_is_dram,
-             static_cast<uint32_t>(weight_has_value),
-             static_cast<uint32_t>(bias_has_value),
-             static_cast<uint32_t>(operation_attributes.training)}));
+        tt_metal::WriterDataMovementConfig({
+            b_is_dram,
+            c_is_dram,
+            d_is_dram,
+            e_is_dram,
+            f_is_dram,
+            static_cast<uint32_t>(weight_has_value),
+            static_cast<uint32_t>(bias_has_value),
+            static_cast<uint32_t>(operation_attributes.training),
+            g_is_dram,
+            h_is_dram,
+            static_cast<uint32_t>(running_mean_has_value),
+            static_cast<uint32_t>(running_var_has_value),
+        }));
 
     // COMPUTE KERNEL
     bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
@@ -253,7 +306,9 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch
     std::vector<uint32_t> compute_kernel_args = {
         static_cast<uint32_t>(weight_has_value),
         static_cast<uint32_t>(bias_has_value),
-        static_cast<uint32_t>(operation_attributes.training)};
+        static_cast<uint32_t>(operation_attributes.training),
+        static_cast<uint32_t>(running_mean_has_value),
+        static_cast<uint32_t>(running_var_has_value)};
     auto compute_kernel_id = tt_metal::CreateKernel(
         program,
         "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp",
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
index b14d3ec5bfcf..e1a99a8ecaaa 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp
@@ -40,6 +40,8 @@ void MAIN {
     constexpr uint32_t weight_has_value = get_compile_time_arg_val(0) == 1;
     constexpr uint32_t bias_has_value = get_compile_time_arg_val(1) == 1;
     constexpr uint32_t is_training_mode = get_compile_time_arg_val(2) == 1;
+    constexpr uint32_t old_running_mean_has_value = get_compile_time_arg_val(3) == 1;
+    constexpr uint32_t old_running_var_has_value = get_compile_time_arg_val(4) == 1;
 
     if (num_tiles == 0) {
         return;
@@ -56,6 +58,11 @@ void MAIN {
     constexpr auto cb_weight = tt::CBIndex::c_16;    // weight tensor
     constexpr auto cb_tmp_1 = tt::CBIndex::c_17;     // (input - batch_mean)/(sqrt(batch_var + eps))
     constexpr auto cb_bias = tt::CBIndex::c_18;      // bias tensor
+    constexpr auto cb_old_running_mean = tt::CBIndex::c_25;  // old running mean tensor
+    constexpr auto cb_old_running_var = tt::CBIndex::c_26;   // old running var tensor
+    constexpr auto cb_updated_running_mean = tt::CBIndex::c_27;  // updated running mean tensor
+    constexpr auto cb_updated_running_var = tt::CBIndex::c_28;   // updated running var tensor
+    constexpr auto cb_momentum = tt::CBIndex::c_24;              // momentum
 
     auto cb_bcast = cb_batch_mean;
     auto cb_other = cb_input;
@@ -118,7 +125,12 @@ void MAIN {
         cb_push_back(cb_affine_or_out, onetile);
 
         if constexpr (is_training_mode) {
-            // update running stats here
+            // update running stats: running_stat = (1 - momentum) * running_stat + momentum * batch_stat
+            if constexpr (old_running_mean_has_value) {
+                // TODO: blend cb_old_running_mean with the batch mean via cb_momentum
+                // and push the result to cb_updated_running_mean for the writer
+            }
+
+            if constexpr (old_running_var_has_value) {
+                // TODO: blend cb_old_running_var with the batch variance via cb_momentum
+                // and push the result to cb_updated_running_var for the writer
+            }
         }
 
         if constexpr (weight_has_value) {  // result = result * weight
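
For reference, the per-tile pipeline the compute kernel assembles from these CBs corresponds to the host-side sketch below (the running-stat blend is still a placeholder above; `batch_norm_tile` is illustrative, not device code):

```python
def batch_norm_tile(x, batch_mean, batch_var, eps, weight=None, bias=None):
    # cb_tmp_1 holds (input - batch_mean) / sqrt(batch_var + eps)
    result = (x - batch_mean) / (batch_var + eps) ** 0.5
    if weight is not None:  # "result = result * weight"
        result = result * weight
    if bias is not None:    # result = result + bias
        result = result + bias
    return result
```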
diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
index 7fdc32b5339d..13c9cf2c3d5a 100644
--- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp
@@ -12,14 +12,16 @@ void kernel_main() {
     uint32_t batch_var_addr = get_arg_val<uint32_t>(1);  // batch_var
     uint32_t weight_addr = get_arg_val<uint32_t>(2);     // weight
     uint32_t bias_addr = get_arg_val<uint32_t>(3);       // bias
-    uint32_t dst_addr = get_arg_val<uint32_t>(4);        // output
-    uint32_t start_tile_id = get_arg_val<uint32_t>(5);
-    uint32_t num_tiles = get_arg_val<uint32_t>(6);
-    uint32_t HtWt = get_arg_val<uint32_t>(7);
-    uint32_t n_stride = get_arg_val<uint32_t>(8);
-    uint32_t c_stride = get_arg_val<uint32_t>(9);
-    uint32_t N = get_arg_val<uint32_t>(10);
-    uint32_t C = get_arg_val<uint32_t>(11);
+    uint32_t old_running_mean_addr = get_arg_val<uint32_t>(4);  // old running_mean
+    uint32_t old_running_var_addr = get_arg_val<uint32_t>(5);   // old running_var
+    uint32_t dst_addr = get_arg_val<uint32_t>(6);               // output
+    uint32_t start_tile_id = get_arg_val<uint32_t>(7);
+    uint32_t num_tiles = get_arg_val<uint32_t>(8);
+    uint32_t HtWt = get_arg_val<uint32_t>(9);
+    uint32_t n_stride = get_arg_val<uint32_t>(10);
+    uint32_t c_stride = get_arg_val<uint32_t>(11);
+    uint32_t N = get_arg_val<uint32_t>(12);
+    uint32_t C = get_arg_val<uint32_t>(13);
 
     constexpr uint32_t onetile = 1;
 
@@ -72,6 +74,33 @@ void kernel_main() {
     constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1;
     constexpr bool is_training_mode = get_compile_time_arg_val(7) == 1;
 
+    // old running mean
+    constexpr auto cb_id_old_running_mean = tt::CBIndex::c_25;
+    constexpr bool old_running_mean_is_dram = get_compile_time_arg_val(8) == 1;
+    const uint32_t old_running_mean_tile_bytes = get_tile_size(cb_id_old_running_mean);
+    const DataFormat old_running_mean_data_format = get_dataformat(cb_id_old_running_mean);
+
+    const InterleavedAddrGenFast<old_running_mean_is_dram> old_running_mean = {
+        .bank_base_address = old_running_mean_addr,
+        .page_size = old_running_mean_tile_bytes,
+        .data_format = old_running_mean_data_format};
+
+    // old running var
+    constexpr auto cb_id_old_running_var = tt::CBIndex::c_26;
+    constexpr bool old_running_var_is_dram = get_compile_time_arg_val(9) == 1;
+    const uint32_t old_running_var_tile_bytes = get_tile_size(cb_id_old_running_var);
+    const DataFormat old_running_var_data_format = get_dataformat(cb_id_old_running_var);
+
+    const InterleavedAddrGenFast<old_running_var_is_dram> old_running_var = {
+        .bank_base_address = old_running_var_addr,
+        .page_size = old_running_var_tile_bytes,
+        .data_format = old_running_var_data_format};
+
+    constexpr bool old_running_mean_has_value = get_compile_time_arg_val(10) == 1;
+    constexpr bool old_running_var_has_value = get_compile_time_arg_val(11) == 1;
+    constexpr auto cb_id_updated_running_mean = tt::CBIndex::c_27;
+    constexpr auto cb_id_updated_running_var = tt::CBIndex::c_28;
+
     uint32_t tiles_per_batch = HtWt * C;
     uint32_t start_n = start_tile_id / tiles_per_batch;
     uint32_t start_remaining = start_tile_id % tiles_per_batch;
@@ -121,6 +150,39 @@ void kernel_main() {
 
             // to read running stats value for updation
             if constexpr (is_training_mode) {
+                if constexpr (old_running_mean_has_value) {
+                    // read data
+                    cb_reserve_back(cb_id_old_running_mean, onetile);
+                    uint32_t l1_old_running_mean_write_addr = get_write_ptr(cb_id_old_running_mean);
+                    noc_async_read_tile(tile_offset, old_running_mean, l1_old_running_mean_write_addr);
+                    noc_async_read_barrier();
+                    fill_tile_with_first_element_bfloat16(cb_id_old_running_mean);
+                    cb_push_back(cb_id_old_running_mean, onetile);
+
+                    // write data
+                    cb_wait_front(cb_id_updated_running_mean, onetile);
+                    uint32_t l1_write_updated_mean_addr = get_read_ptr(cb_id_updated_running_mean);
+                    noc_async_write_tile(tile_offset, old_running_mean, l1_write_updated_mean_addr);
+                    noc_async_write_barrier();
+                    cb_pop_front(cb_id_updated_running_mean, onetile);
+                }
+
+                if constexpr (old_running_var_has_value) {
+                    // read data
+                    cb_reserve_back(cb_id_old_running_var, onetile);
+                    uint32_t l1_old_running_var_write_addr = get_write_ptr(cb_id_old_running_var);
+                    noc_async_read_tile(tile_offset, old_running_var, l1_old_running_var_write_addr);
+                    noc_async_read_barrier();
+                    fill_tile_with_first_element_bfloat16(cb_id_old_running_var);
+                    cb_push_back(cb_id_old_running_var, onetile);
+
+                    // write data
+                    cb_wait_front(cb_id_updated_running_var, onetile);
+                    uint32_t l1_write_updated_var_addr = get_read_ptr(cb_id_updated_running_var);
+                    noc_async_write_tile(tile_offset, old_running_var, l1_write_updated_var_addr);
+                    noc_async_write_barrier();
+                    cb_pop_front(cb_id_updated_running_var, onetile);
+                }
             }
 
             for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
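
The read/modify/write above is what makes the in-place running-stat update work: for each channel the writer reads the old stat tile, replicates its first element across the tile (fill_tile_with_first_element_bfloat16), waits for the compute kernel to publish the blended tile on cb_updated_running_*, and writes it back to the same bank address. A host-side model of one channel's round trip, assuming the standard momentum blend the compute placeholder is intended to produce:

```python
import numpy as np

def running_stat_roundtrip(old_stat, batch_stat, momentum, tile_hw=(32, 32)):
    # Reader side: one scalar per channel, replicated across the tile.
    old_tile = np.full(tile_hw, old_stat, dtype=np.float32)
    # Compute side (placeholder in the kernel): standard BN blend, assumed.
    updated_tile = (1.0 - momentum) * old_tile + momentum * batch_stat
    # Writer side: the updated tile overwrites the old stat in place.
    return float(updated_tile[0, 0])
```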