From 052177f344c6cfb093522ea31453be68b06c1068 Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Tue, 7 Jan 2025 18:47:08 +0000 Subject: [PATCH] #12253: Update files --- .../eltwise/backward/utility_funcs.py | 35 +++--- .../unit_tests/operations/test_batch_norm.py | 106 +++++++++--------- .../batch_norm/batch_norm_pybind.cpp | 13 ++- .../device/batch_norm_device_operation.cpp | 21 +--- .../device/batch_norm_program_factory.cpp | 63 ++++------- .../kernels/compute/batch_norm_kernel.cpp | 4 +- .../kernels/dataflow/writer_batch_norm.cpp | 29 ++--- 7 files changed, 119 insertions(+), 152 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py b/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py index 52143724764..6319e4b4f37 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py @@ -11,24 +11,29 @@ ) -def data_gen_with_range_batch_norm(input_shapes, low, high, device, is_input=False, required_grad=False): +def data_gen_with_range_batch_norm( + input_shapes, + low, + high, + device, + is_input=False, + required_grad=False, +): assert high > low, "Incorrect range provided" torch.manual_seed(213919) channels = input_shapes[1] - if is_input: - pt_tensor = torch.rand(input_shapes, requires_grad=required_grad).bfloat16() * (high - low) + low - tt_tensor = ttnn.from_torch( - pt_tensor, - device=device, - layout=ttnn.TILE_LAYOUT, - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - else: - pt_tensor = torch.rand(channels, requires_grad=required_grad).bfloat16() * (high - low) + low - # pt_tensor = pt_tensor.view(1, channels, 1, 1) # to test each section of TT op - reshaped_tensor = pt_tensor.view(1, channels, 1, 1).expand(1, channels, 32, 32) - tt_tensor = ttnn.Tensor(reshaped_tensor, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + size = input_shapes if is_input else channels + pt_tensor = torch.rand(size, requires_grad=required_grad).bfloat16() * (high - low) + low + reshaped_tensor = pt_tensor + if not is_input: + reshaped_tensor = pt_tensor.view(1, channels, 1, 1) + tt_tensor = ttnn.from_torch( + reshaped_tensor, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) return pt_tensor, tt_tensor diff --git a/tests/ttnn/unit_tests/operations/test_batch_norm.py b/tests/ttnn/unit_tests/operations/test_batch_norm.py index 42ea7d8f979..5b2287201a2 100644 --- a/tests/ttnn/unit_tests/operations/test_batch_norm.py +++ b/tests/ttnn/unit_tests/operations/test_batch_norm.py @@ -9,71 +9,35 @@ data_gen_with_range_batch_norm, compare_results_batch_norm, ) +from itertools import product @pytest.mark.parametrize( "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 2, 32, 32])), - (torch.Size([1, 3, 32, 32])), - (torch.Size([2, 1, 32, 32])), - (torch.Size([2, 2, 32, 32])), - (torch.Size([2, 3, 32, 32])), - (torch.Size([3, 1, 32, 32])), - (torch.Size([3, 2, 32, 32])), - (torch.Size([3, 3, 32, 32])), - (torch.Size([4, 1, 32, 32])), - (torch.Size([4, 2, 32, 32])), - (torch.Size([4, 3, 32, 32])), - (torch.Size([4, 4, 32, 32])), - (torch.Size([1, 1, 23, 23])), - (torch.Size([1, 2, 23, 23])), - (torch.Size([1, 3, 23, 23])), - (torch.Size([2, 1, 23, 23])), - (torch.Size([2, 2, 23, 23])), - (torch.Size([2, 3, 23, 23])), - (torch.Size([3, 1, 23, 23])), - (torch.Size([3, 2, 23, 23])), - (torch.Size([3, 3, 23, 23])), - (torch.Size([4, 1, 23, 23])), - (torch.Size([4, 2, 23, 23])), - (torch.Size([4, 3, 23, 23])), - (torch.Size([4, 4, 23, 23])), - (torch.Size([1, 1, 64, 120])), - (torch.Size([1, 2, 64, 120])), - (torch.Size([1, 3, 64, 120])), - (torch.Size([2, 1, 64, 120])), - (torch.Size([2, 2, 64, 120])), - (torch.Size([2, 3, 64, 120])), - (torch.Size([3, 1, 64, 120])), - (torch.Size([3, 2, 64, 120])), - ), + [ + *(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), + torch.Size([4, 4, 32, 32]), + *(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4], [1, 2, 3])), + torch.Size([4, 4, 23, 23]), + *(torch.Size([n, c, 64, 120]) for n, c in product([1, 2], [1, 2, 3])), + torch.Size([3, 1, 64, 120]), + torch.Size([3, 2, 64, 120]), + ], ) @pytest.mark.parametrize("training", [False]) @pytest.mark.parametrize("weight", [True, False]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("eps", [1.0, 0.0, 2.34, 1e-05]) def test_batch_norm(input_shapes, training, weight, bias, eps, device): - in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, True, False) - if not training: - mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device, False, False) - var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device, False, False) - else: - mean_data = None - mean_tensor = None - var_data = None - var_tensor = None - if weight: - weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device, False, False) - else: - weight_data = None - weight_tensor = None - if bias: - bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device, False, False) - else: - bias_data = None - bias_tensor = None + in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True) + mean_data, mean_tensor = ( + data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if (not training) else (None, None) + ) + var_data, var_tensor = ( + data_gen_with_range_batch_norm(input_shapes, 4, 20, device) if (not training) else (None, None) + ) + weight_data, weight_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if weight else (None, None) + bias_data, bias_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) if bias else (None, None) tt_output_tensor_on_device = ttnn.batch_norm( input_tensor, @@ -100,3 +64,35 @@ def test_batch_norm(input_shapes, training, weight, bias, eps, device): # print("Torch result : ",torch_result) comp_pass = compare_results_batch_norm([tt_output], [torch_result]) assert comp_pass + + +@pytest.mark.parametrize( + "input_shapes", + [ + torch.Size([3, 2, 32, 32]), + ], +) +@pytest.mark.parametrize("mem_layout", [ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.TensorMemoryLayout.HEIGHT_SHARDED]) +def test_batch_norm_program_cache_and_default(input_shapes, mem_layout, device): + N, H, W, C = input_shapes + in_data, input_tensor = data_gen_with_range_batch_norm(input_shapes, 5, 10, device, is_input=True) + mean_data, mean_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 10, device) + var_data, var_tensor = data_gen_with_range_batch_norm(input_shapes, 4, 20, device) + + grid_size = ttnn.CoreGrid(y=1, x=8) + grid_coord = ttnn.CoreCoord(grid_size.x - 1, grid_size.y - 1) + shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) + shard_shape = N * H * W // grid_size.x, C // grid_size.y + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False) + sharded_mem_config = ttnn.MemoryConfig(mem_layout, ttnn.types.BufferType.L1, shard_spec) + + if mem_layout is not ttnn.TensorMemoryLayout.INTERLEAVED: + pytest.xfail("Input tensors to batch norm must be interleaved") + + tt_output_tensor_on_device = ttnn.batch_norm( + input_tensor, running_mean=mean_tensor, running_var=var_tensor, memory_config=sharded_mem_config + ) + tt_output = ttnn.to_torch(tt_output_tensor_on_device) + torch_result = torch.nn.functional.batch_norm(input=in_data, running_mean=mean_data, running_var=var_data) + comp_pass = compare_results_batch_norm([tt_output], [torch_result]) + assert comp_pass diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp index f0c159b61c4..5428b1aa4f3 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/batch_norm_pybind.cpp @@ -14,20 +14,21 @@ void bind_batch_norm_operation(pybind11::module& module) { module, ttnn::batch_norm, R"doc( - Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`.Currently support is provided for inference mode only. + Applies Spatial Batch Normalization over each channel on :attr:`input_tensor`. Inputs must be must be tilized and interleaved. Currently support is provided for inference mode only. Args: - input_tensor (ttnn.Tensor): the input tensor. + input_tensor (ttnn.Tensor): the input tensor of shape `[N, C, H, W]`. Keyword args: eps (float, optional): Epsilon value. Defaults to `1e-05`. - running_mean (ttnn.Tensor, optional): the running_mean required for inference mode. Defaults to `None`. - running_var (ttnn.Tensor, optional): the running_var required for inference mode. Defaults to `None`. - weight (ttnn.Tensor, optional): the weight or gamma value. Defaults to `None`. - bias (ttnn.Tensor, optional): the bias or beta value. Defaults to `None`. + running_mean (ttnn.Tensor, optional): the running_mean of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. + running_var (ttnn.Tensor, optional): the running_var of shape `[1, C, 1, 1]`, required in inference mode . Defaults to `None`. + weight (ttnn.Tensor, optional): the weight or gamma value of shape `[1, C, 1, 1]`. Defaults to `None`. + bias (ttnn.Tensor, optional): the bias or beta value of shape `[1, C, 1, 1]`. Defaults to `None`. training (bool, optional): Selection between training mode and inference (evaluation) mode. Defaults to `False` (Inference mode). + output (ttnn.Tensor, optional): Preallocated output tensor to store batch norm result of shape `[N, C, H, W]`. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp index 9528282d1a0..3756cf56164 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_device_operation.cpp @@ -10,14 +10,7 @@ namespace ttnn::operations::normalization { void BatchNormOperation::validate_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto& input = tensor_args.input; - const auto& batch_mean = tensor_args.batch_mean; - const auto& batch_var = tensor_args.batch_var; - const auto& eps = operation_attributes.eps; - const auto& weight = tensor_args.weight; - const auto& bias = tensor_args.bias; - - auto& output = tensor_args.output; + const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args; check_tensor(input, "batch_norm", "input"); check_tensor(batch_mean, "batch_norm", "batch_mean"); @@ -49,8 +42,8 @@ void BatchNormOperation::validate_tensors( // bias (1, C, 1, 1) if (bias.has_value()) { - TT_FATAL(bias.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size."); - TT_FATAL(bias.value().get_logical_shape()[1] == C, "weight_shape[1] must be the same as input's channel size."); + TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size."); + TT_FATAL(bias.value().get_logical_shape()[1] == C, "bias_shape[1] must be the same as input's channel size."); } } @@ -61,13 +54,7 @@ BatchNormOperation::program_factory_t BatchNormOperation::select_program_factory void BatchNormOperation::validate_on_program_cache_miss( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - // We don't support sharding for now - const auto& input = tensor_args.input; - const auto& batch_mean = tensor_args.batch_mean; - const auto& batch_var = tensor_args.batch_var; - const auto& weight = tensor_args.weight; - const auto& bias = tensor_args.bias; - const auto& output = tensor_args.output; + const auto& [input, batch_mean, batch_var, weight, bias, output] = tensor_args; TT_FATAL(input.get_layout() == Layout::TILE, "Input tensor must be must be tilized"); TT_FATAL( diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp index 0c9850c5dcc..f52a2192b9f 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/batch_norm_program_factory.cpp @@ -29,11 +29,7 @@ void set_or_update_runtime_arguments( const BatchNormOperation::tensor_args_t& tensor_args, BatchNormOperation::tensor_return_value_t& c, F handle_args) { - const auto& a = tensor_args.input; - const auto& b = tensor_args.batch_mean; - const auto& d = tensor_args.batch_var; - const auto& e = tensor_args.weight; - const auto& f = tensor_args.bias; + const auto& [a, b, d, e, f, _] = tensor_args; const auto eps = operation_attributes.eps; const bool weight_has_value = e.has_value(); @@ -68,7 +64,7 @@ void set_or_update_runtime_arguments( num_tiles_per_core = num_tiles_per_core_group_2; } else { handle_args(program, reader_kernel_id, core, std::array{0}); - handle_args(program, writer_kernel_id, core, std::array{0}); + handle_args(program, writer_kernel_id, core, std::array{0}); handle_args(program, compute_kernel_id, core, std::array{0}); continue; } @@ -90,14 +86,12 @@ void set_or_update_runtime_arguments( cWt}; handle_args(program, reader_kernel_id, core, reader_runtime_args); - const auto weight_addr = weight_has_value ? e.value().buffer()->address() : 0; - const auto bias_addr = bias_has_value ? f.value().buffer()->address() : 0; + const auto weight_addr = weight_has_value ? e->buffer()->address() : 0; + const auto bias_addr = bias_has_value ? f->buffer()->address() : 0; std::array writer_runtime_args = { b.buffer()->address(), // batch mean d.buffer()->address(), // batch var - static_cast(weight_has_value), - weight_addr, // weight - static_cast(bias_has_value), + weight_addr, // weight bias_addr, // bias c.buffer()->address(), // output start_tile_id, @@ -132,12 +126,7 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch using namespace tt; using namespace tt::tt_metal; - const auto& a = tensor_args.input; - const auto& b = tensor_args.batch_mean; - const auto& d = tensor_args.batch_var; - const auto& eps = operation_attributes.eps; - const auto& e = tensor_args.weight; - const auto& f = tensor_args.bias; + const auto& [a, b, d, e, f, _] = tensor_args; auto program = CreateProgram(); @@ -169,13 +158,6 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch uint32_t num_cores_y = compute_with_storage_grid_size.y; auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); - Buffer* a_buffer = a.buffer(); - Buffer* b_buffer = b.buffer(); - Buffer* c_buffer = output.buffer(); - Buffer* d_buffer = d.buffer(); - Buffer* e_buffer = nullptr; - Buffer* f_buffer = nullptr; - // Number of tiles to store per input CB (double buffer) constexpr uint32_t num_tiles_per_cb = 2; uint32_t b_num_tiles_per_cb = num_tiles_per_cb; @@ -224,24 +206,12 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch auto [temp_1_cb, temp_1_cb_handle] = create_cb(tt::CBIndex::c_17, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); - auto a_is_dram = static_cast(a_buffer->buffer_type() == tt_metal::BufferType::DRAM); - auto b_is_dram = static_cast(b_buffer->buffer_type() == tt_metal::BufferType::DRAM); - auto c_is_dram = static_cast(c_buffer->buffer_type() == tt_metal::BufferType::DRAM); - auto d_is_dram = static_cast(d_buffer->buffer_type() == tt_metal::BufferType::DRAM); - bool e_is_dram = false; - bool f_is_dram = false; - - // weight - if (weight_has_value) { - e_buffer = e->buffer(); - e_is_dram = static_cast(e_buffer->buffer_type() == tt_metal::BufferType::DRAM); - } - - // bias - if (bias_has_value) { - f_buffer = f->buffer(); - f_is_dram = static_cast(f_buffer->buffer_type() == tt_metal::BufferType::DRAM); - } + auto a_is_dram = static_cast(a.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto b_is_dram = static_cast(b.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto c_is_dram = static_cast(output.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + auto d_is_dram = static_cast(d.buffer()->buffer_type() == tt_metal::BufferType::DRAM); + const auto e_is_dram = weight_has_value and e->buffer()->buffer_type() == tt_metal::BufferType::DRAM; + const auto f_is_dram = bias_has_value and f->buffer()->buffer_type() == tt_metal::BufferType::DRAM; // READER KERNEL auto reader_kernel_id = tt_metal::CreateKernel( @@ -255,7 +225,14 @@ BatchNormOperation::BatchNormFactory::cached_program_t BatchNormOperation::Batch program, "ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp", all_device_cores, - tt_metal::WriterDataMovementConfig({b_is_dram, c_is_dram, d_is_dram, e_is_dram, f_is_dram})); + tt_metal::WriterDataMovementConfig( + {b_is_dram, + c_is_dram, + d_is_dram, + e_is_dram, + f_is_dram, + static_cast(weight_has_value), + static_cast(bias_has_value)})); // COMPUTE KERNEL bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 || diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp index 7401bbf5ee2..7416cb88d97 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp @@ -116,7 +116,7 @@ void MAIN { cb_pop_front(cb_den, 1); cb_push_back(cb_affine_or_out, onetile); - if (weight_has_value) { // result = result * weight + if constexpr (weight_has_value) { // result = result * weight cb_reserve_back(cb_scaled_output, onetile); cb_wait_front(cb_affine_or_out, 1); cb_wait_front(cb_weight, 1); @@ -134,7 +134,7 @@ void MAIN { cb_pop_front(cb_weight, 1); cb_push_back(cb_scaled_output, onetile); } - if (bias_has_value) { // result = result + bias + if constexpr (bias_has_value) { // result = result + bias cb_reserve_back(cb_output_0, 1); cb_wait_front(cb_tmp_1, 1); cb_wait_front(cb_bias, 1); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp index 3fecc5ee29a..44396001114 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/dataflow/writer_batch_norm.cpp @@ -10,18 +10,16 @@ void kernel_main() { uint32_t src_addr = get_arg_val(0); // batch_mean uint32_t batch_var_addr = get_arg_val(1); // batch_var - const bool weight_has_value = get_arg_val(2) == 1; - uint32_t weight_addr = get_arg_val(3); // weight - const bool bias_has_value = get_arg_val(4) == 1; - uint32_t bias_addr = get_arg_val(5); // bias - uint32_t dst_addr = get_arg_val(6); // output - uint32_t start_tile_id = get_arg_val(7); - uint32_t num_tiles = get_arg_val(8); - uint32_t HtWt = get_arg_val(9); - uint32_t n_stride = get_arg_val(10); - uint32_t c_stride = get_arg_val(11); - uint32_t N = get_arg_val(12); - uint32_t C = get_arg_val(13); + uint32_t weight_addr = get_arg_val(2); // weight + uint32_t bias_addr = get_arg_val(3); // bias + uint32_t dst_addr = get_arg_val(4); // output + uint32_t start_tile_id = get_arg_val(5); + uint32_t num_tiles = get_arg_val(6); + uint32_t HtWt = get_arg_val(7); + uint32_t n_stride = get_arg_val(8); + uint32_t c_stride = get_arg_val(9); + uint32_t N = get_arg_val(10); + uint32_t C = get_arg_val(11); constexpr uint32_t onetile = 1; @@ -70,6 +68,9 @@ void kernel_main() { const InterleavedAddrGenFast bias = { .bank_base_address = bias_addr, .page_size = bias_tile_bytes, .data_format = bias_data_format}; + constexpr bool weight_has_value = get_compile_time_arg_val(5) == 1; + constexpr bool bias_has_value = get_compile_time_arg_val(6) == 1; + uint32_t tiles_per_batch = HtWt * C; uint32_t start_n = start_tile_id / tiles_per_batch; uint32_t start_remaining = start_tile_id % tiles_per_batch; @@ -99,7 +100,7 @@ void kernel_main() { fill_tile_with_first_element_bfloat16(cb_id_batch_var); cb_push_back(cb_id_batch_var, onetile); - if (weight_has_value) { // read a tile from weight tensor + if constexpr (weight_has_value) { // read a tile from weight tensor cb_reserve_back(cb_id_weight, onetile); uint32_t l1_weight_write_addr = get_write_ptr(cb_id_weight); noc_async_read_tile(tile_offset, weight, l1_weight_write_addr); @@ -108,7 +109,7 @@ void kernel_main() { cb_push_back(cb_id_weight, onetile); } - if (bias_has_value) { // read a tile from bias tensor + if constexpr (bias_has_value) { // read a tile from bias tensor cb_reserve_back(cb_id_bias, onetile); uint32_t l1_bias_write_addr = get_write_ptr(cb_id_bias); noc_async_read_tile(tile_offset, bias, l1_bias_write_addr);