diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py index fed420eacdb0..4f51badab2b0 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py @@ -469,7 +469,6 @@ def test_matmul_in1_dram_sharded_with_mm_chain( ) -@pytest.mark.skipif(is_grayskull(), reason="not tested for GS") @pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) @pytest.mark.parametrize( "fp32_acc_mode", @@ -600,12 +599,18 @@ def test_matmul_2d_in1_dram_sharded( fused_activation=activation, ) - compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( - math_fidelity=fidelity, - math_approx_mode=True, - fp32_dest_acc_en=fp32_acc_mode, - packer_l1_acc=packer_l1_acc, - ) + if is_grayskull(): + compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig( + math_fidelity=fidelity, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=fidelity, + math_approx_mode=True, + fp32_dest_acc_en=fp32_acc_mode, + packer_l1_acc=packer_l1_acc, + ) if has_bias: output_t = ttl.operations.primary.matmul( in0_t, diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp index b9fd148ec8be..2d55dc25a541 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp @@ -96,10 +96,11 @@ void kernel_main() { // RT and COMPILE TIME ARGS for DRAM sharded weights #ifdef IN1_DRAM_SHARDED - const uint32_t num_dram_shards_to_read = get_arg_val(18); - const uint32_t dram_tensor_start_offset = get_arg_val(19); - volatile tt_l1_ptr uint32_t * in1_block_w_dram_stride_bytes = (volatile tt_l1_ptr uint32_t*)get_arg_addr(20); - volatile tt_l1_ptr uint32_t * current_dram_bank_id = (volatile tt_l1_ptr uint32_t*)get_arg_addr(21); + const uint32_t vc = get_arg_val(18); + const uint32_t num_dram_shards_to_read = get_arg_val(19); + const uint32_t dram_tensor_start_offset = get_arg_val(20); + volatile tt_l1_ptr uint32_t * in1_block_w_dram_stride_bytes = (volatile tt_l1_ptr uint32_t*)get_arg_addr(21); + volatile tt_l1_ptr uint32_t * current_dram_bank_id = (volatile tt_l1_ptr uint32_t*)get_arg_addr(22); constexpr uint32_t in1_dram_block_num_tiles = get_compile_time_arg_val(26); constexpr uint32_t in1_block_w_dram_bytes= get_compile_time_arg_val(27); @@ -180,7 +181,7 @@ void kernel_main() { uint32_t next_bank_id_and_dram_stride_index = 0; for (uint32_t i = 0; i < num_dram_shards_to_read; ++i) { - uint32_t in1_base_addr = noc_async_read_tile_dram_sharded_set_state(in1_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index]); + uint32_t in1_base_addr = noc_async_read_tile_dram_sharded_set_state(in1_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index], vc); if (i == 0) { in1_base_addr += dram_tensor_start_offset; @@ -274,7 +275,7 @@ void kernel_main() { uint32_t next_bank_id_and_dram_stride_index = 0; for (uint32_t i = 0; i < num_dram_shards_to_read; ++i) { - uint32_t in3_base_addr = noc_async_read_tile_dram_sharded_set_state(in3_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index]); + uint32_t in3_base_addr = noc_async_read_tile_dram_sharded_set_state(in3_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index], vc); if (i == 0) { in3_base_addr += dram_tensor_start_offset; diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp index 815a23f119dd..f5b333ee4dc8 100644 --- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp @@ -689,6 +689,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( uint32_t storage_core_stride = 0; // stride in the dram bank uint32_t curr_worker_core = 0; // current worker core uint32_t curr_storage_core = 0; // current read dram bank + uint32_t vc = 0; const auto& cores = grid_to_cores(all_cores.start, all_cores.end, true); const auto& in0_sender_cores = grid_to_cores(in0_sender.start, in0_sender.end, true); @@ -855,6 +856,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( mm_in1_sender_writer_args.push_back(0); } + vc = vc == 3 ? 0 : vc+1; + mm_in1_sender_writer_args.push_back(vc); + uint32_t num_iter = 0; // iterate how many banks, till fill the current worker block if (curr_storage_core < num_dram_banks) { @@ -912,7 +916,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( worker_core_stride = stride; } } - mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 18, num_iter); + mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 19, num_iter); } tt_metal::SetRuntimeArgs( program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); // RISCV_1_default