Skip to content

Commit

Permalink
#0: use vc for mm2d in1 dram sharded
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Jun 6, 2024
1 parent e258ed4 commit 554bdb7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,6 @@ def test_matmul_in1_dram_sharded_with_mm_chain(
)


@pytest.mark.skipif(is_grayskull(), reason="not tested for GS")
@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"])
@pytest.mark.parametrize(
"fp32_acc_mode",
Expand Down Expand Up @@ -600,12 +599,18 @@ def test_matmul_2d_in1_dram_sharded(
fused_activation=activation,
)

compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
math_fidelity=fidelity,
math_approx_mode=True,
fp32_dest_acc_en=fp32_acc_mode,
packer_l1_acc=packer_l1_acc,
)
if is_grayskull():
compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig(
math_fidelity=fidelity,
math_approx_mode=True,
)
else:
compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig(
math_fidelity=fidelity,
math_approx_mode=True,
fp32_dest_acc_en=fp32_acc_mode,
packer_l1_acc=packer_l1_acc,
)
if has_bias:
output_t = ttl.operations.primary.matmul(
in0_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ void kernel_main() {

// RT and COMPILE TIME ARGS for DRAM sharded weights
#ifdef IN1_DRAM_SHARDED
const uint32_t num_dram_shards_to_read = get_arg_val<uint32_t>(18);
const uint32_t dram_tensor_start_offset = get_arg_val<uint32_t>(19);
volatile tt_l1_ptr uint32_t * in1_block_w_dram_stride_bytes = (volatile tt_l1_ptr uint32_t*)get_arg_addr(20);
volatile tt_l1_ptr uint32_t * current_dram_bank_id = (volatile tt_l1_ptr uint32_t*)get_arg_addr(21);
const uint32_t vc = get_arg_val<uint32_t>(18);
const uint32_t num_dram_shards_to_read = get_arg_val<uint32_t>(19);
const uint32_t dram_tensor_start_offset = get_arg_val<uint32_t>(20);
volatile tt_l1_ptr uint32_t * in1_block_w_dram_stride_bytes = (volatile tt_l1_ptr uint32_t*)get_arg_addr(21);
volatile tt_l1_ptr uint32_t * current_dram_bank_id = (volatile tt_l1_ptr uint32_t*)get_arg_addr(22);

constexpr uint32_t in1_dram_block_num_tiles = get_compile_time_arg_val(26);
constexpr uint32_t in1_block_w_dram_bytes= get_compile_time_arg_val(27);
Expand Down Expand Up @@ -180,7 +181,7 @@ void kernel_main() {
uint32_t next_bank_id_and_dram_stride_index = 0;

for (uint32_t i = 0; i < num_dram_shards_to_read; ++i) {
uint32_t in1_base_addr = noc_async_read_tile_dram_sharded_set_state<in1_single_tile_size_bytes, false>(in1_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index]);
uint32_t in1_base_addr = noc_async_read_tile_dram_sharded_set_state<in1_single_tile_size_bytes, true>(in1_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index], vc);

if (i == 0) {
in1_base_addr += dram_tensor_start_offset;
Expand Down Expand Up @@ -274,7 +275,7 @@ void kernel_main() {
uint32_t next_bank_id_and_dram_stride_index = 0;

for (uint32_t i = 0; i < num_dram_shards_to_read; ++i) {
uint32_t in3_base_addr = noc_async_read_tile_dram_sharded_set_state<bias_single_tile_size_bytes, false>(in3_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index]);
uint32_t in3_base_addr = noc_async_read_tile_dram_sharded_set_state<bias_single_tile_size_bytes, true>(in3_tensor_addr, current_dram_bank_id[next_bank_id_and_dram_stride_index], vc);

if (i == 0) {
in3_base_addr += dram_tensor_start_offset;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
uint32_t storage_core_stride = 0; // stride in the dram bank
uint32_t curr_worker_core = 0; // current worker core
uint32_t curr_storage_core = 0; // current read dram bank
uint32_t vc = 0;

const auto& cores = grid_to_cores(all_cores.start, all_cores.end, true);
const auto& in0_sender_cores = grid_to_cores(in0_sender.start, in0_sender.end, true);
Expand Down Expand Up @@ -855,6 +856,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(0);
}

vc = vc == 3 ? 0 : vc+1;
mm_in1_sender_writer_args.push_back(vc);

uint32_t num_iter = 0; // iterate how many banks, till fill the current worker block

if (curr_storage_core < num_dram_banks) {
Expand Down Expand Up @@ -912,7 +916,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
worker_core_stride = stride;
}
}
mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 18, num_iter);
mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 19, num_iter);
}
tt_metal::SetRuntimeArgs(
program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); // RISCV_1_default
Expand Down

0 comments on commit 554bdb7

Please sign in to comment.