#16626: Add support for unpadded shapes in Matmul1D w/ gather_in0 (#16627)

### Ticket
- #16626

### Problem description
In the current use case of Matmul1D with gather_in0 in the Llama models,
the activations and weights need to be padded. This results in
significant overhead.

### What's changed
- Added support to skip the part of in0_block_w that is padding
- Pad Kt and Nt in the host code for gather_in0 (see the sketch below)
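
As a rough illustration, here is a minimal Python sketch of the host-side padding logic; it mirrors the `round_up` helper and per-shard computation added in the test files below. The worked example assumes the new `K=96, N=64` case on a `(2, 1)` grid, i.e. 2 cores.

```python
import math

TILE_SIZE = 32  # ttnn.TILE_SIZE


def round_up(a, b):
    """Round up a to the nearest multiple of b."""
    return b * math.ceil(a / b)


def pad_gather_in0_dims(K, N, num_cores):
    # Round each core's K/N shard up to a whole number of tiles, then derive
    # the padded global dims from the per-shard sizes.
    K_per_shard = round_up(math.ceil(K / num_cores), TILE_SIZE)
    K_padded = K_per_shard * num_cores
    N_per_shard = round_up(math.ceil(N / num_cores), TILE_SIZE)
    N_padded = N_per_shard * num_cores
    return K_per_shard, K_padded, N_per_shard, N_padded


# Worked example for the new unpadded test shape K=96, N=64 on 2 cores:
print(pad_gather_in0_dims(96, 64, 2))  # (64, 128, 32, 64)
```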

### Checklist
- [x] Post commit CI passes
(https://github.com/tenstorrent/tt-metal/actions/runs/12893880800)
- [x] New/Existing tests provide coverage for changes
(https://github.com/tenstorrent/tt-metal/actions/runs/12893883783)
avoraTT authored Jan 24, 2025
1 parent 4649519 commit c8b0fa8
Showing 9 changed files with 208 additions and 55 deletions.
@@ -55,6 +55,13 @@ def get_physical_to_logical_core_mapping(device):
return mapping


def round_up(a, b):
"""
Round up a to the nearest multiple of b
"""
return b * math.ceil(a / b)


# physical coords
PREFETCHER_GRID = [
(8, 11),
@@ -147,13 +154,21 @@ def run_multi_core_matmul_1d(

M *= B # Fuse batch always enabled

K_per_shard = round_up(math.ceil(K / num_cores), ttnn.TILE_SIZE)
K_padded = K_per_shard * num_cores
N_per_shard = round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE)
N_padded = N_per_shard * num_cores

in0_block_h = M // ttnn.TILE_SIZE
in0_block_w = K // num_cores // ttnn.TILE_SIZE
while (K / ttnn.TILE_SIZE) % in0_block_w != 0:
in0_block_w -= 1

out_block_h = M // ttnn.TILE_SIZE
out_block_w = N // num_cores // ttnn.TILE_SIZE
out_block_w = N_padded // num_cores // ttnn.TILE_SIZE

num_blocks_y = (M // ttnn.TILE_SIZE - 1) // out_block_h + 1
num_blocks_x = (N // ttnn.TILE_SIZE - 1) // out_block_w + 1
num_blocks_x = (N_padded // ttnn.TILE_SIZE - 1) // out_block_w + 1
num_blocks_total = num_blocks_y * num_blocks_x

if num_blocks_total != num_cores:
@@ -217,7 +232,7 @@ def run_multi_core_matmul_1d(
ttnn.BufferType.L1,
ttnn.ShardSpec(
core_range_set,
[M, K // num_cores],
[M, K_per_shard],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
@@ -227,7 +242,7 @@ def run_multi_core_matmul_1d(
ttnn.BufferType.L1,
ttnn.ShardSpec(
core_range_set,
[K, N // num_cores],
[K_padded, N_per_shard],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
@@ -237,7 +252,7 @@ def run_multi_core_matmul_1d(
ttnn.BufferType.L1,
ttnn.ShardSpec(
core_range_set,
[M, N // num_cores],
[M, N_per_shard],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
@@ -313,6 +328,79 @@ def run_multi_core_matmul_1d(
assert device.num_program_cache_entries() == 1 # Only 1 op


@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.skipif(is_blackhole(), reason="Test suite for WH only")
@pytest.mark.parametrize("has_bias", [False], ids=["no_bias"])
@pytest.mark.parametrize(
"B, M, K, N, in0_dtype, in1_dtype, fidelity, packer_l1_acc, fp32_acc_mode, grid",
[
(1, 32, 2048, 1280, ttnn.bfloat16, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2, True, True, (8, 3)),
(1, 32, 1280, 2048, ttnn.bfloat16, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2, True, True, (8, 3)),
(1, 32, 2048, 3584, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, False, (8, 3)),
(1, 32, 2048, 3584, ttnn.bfloat8_b, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, False, (8, 3)),
(1, 32, 3584, 2048, ttnn.bfloat16, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2, True, False, (8, 3)),
(1, 32, 96, 64, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (2, 1)),
],
)
@pytest.mark.parametrize(
"activation",
[
None,
],
)
@pytest.mark.parametrize(
"use_arbitrary_cores, hop_grid",
[
(False, None),
(False, [(3, 6)]),
],
)
@pytest.mark.parametrize(
"num_iters",
[
1,
],
)
def test_multi_core_matmul_1d_pad_wh(
device,
in0_dtype,
in1_dtype,
fidelity,
has_bias,
fp32_acc_mode,
packer_l1_acc,
B,
M,
K,
N,
activation,
grid,
hop_grid,
use_arbitrary_cores,
num_iters,
use_program_cache,
function_level_defaults,
):
run_multi_core_matmul_1d(
device,
in0_dtype,
in1_dtype,
fidelity,
has_bias,
fp32_acc_mode,
packer_l1_acc,
B,
M,
K,
N,
activation,
grid,
use_arbitrary_cores,
num_iters,
hop_grid=hop_grid,
)


@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.skipif(is_blackhole(), reason="Test suite for WH only")
@pytest.mark.parametrize("has_bias", [False], ids=["no_bias"])
@@ -447,7 +535,6 @@ def test_multi_core_matmul_1d_wh(
"num_iters",
[1, 3],
)
@pytest.mark.parametrize("device_params", [{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}], indirect=True)
def test_multi_core_matmul_1d_ring_hop_wh(
device,
in0_dtype,
37 changes: 29 additions & 8 deletions tests/ttnn/unit_tests/operations/prefetcher_common.py
@@ -5,6 +5,7 @@
import pytest
import torch
import ttnn
import math
from loguru import logger

from ttnn import ReplicateTensorToMesh, ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor
@@ -20,6 +21,7 @@
run_multi_core_matmul_1d,
PREFETCHER_NOC1_GRID,
num_cores_to_rectangle_grid,
round_up,
)


@@ -210,6 +212,17 @@ def run_prefetcher_mm(
dram_core_range_set = ttnn.CoreRangeSet([ttnn.CoreRange(core_coord, core_coord) for core_coord in dram_cores])
sender_core_range_set = ttnn.CoreRangeSet([ttnn.CoreRange(core_coord, core_coord) for core_coord in sender_cores])

padded_shapes, shard_shapes = [], []
for K, N in input_shapes:
num_cores = len(receiver_cores_list)
K_per_shard = round_up(math.ceil(K / num_cores), ttnn.TILE_SIZE)
K_padded = K_per_shard * num_cores
N_per_shard = round_up(math.ceil(N / num_cores), ttnn.TILE_SIZE)
N_padded = N_per_shard * num_cores

padded_shapes.append((K_padded, N_padded))
shard_shapes.append((K_per_shard, N_per_shard))

cluster_shape = None
mesh_mapper = None
mesh_composer = None
@@ -225,7 +238,7 @@ def run_prefetcher_mm(

tt_tensors_all = []
for tid in range(num_tensors * num_layers):
K, N = input_shapes[tid % num_tensors]
K, N = padded_shapes[tid % num_tensors]
input_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.DRAM,
@@ -286,15 +299,22 @@ def run_prefetcher_mm(
block_dims = []
for tid in range(num_tensors):
K, N = input_shapes[tid]
_, N_padded = padded_shapes[tid]
K_per_shard, N_per_shard = shard_shapes[tid]

in0_shape = [1, 1, M, K]
in0_shapes.append(in0_shape)
out_shape = [1, 1, M, N]

out_shape = [1, 1, M, N_per_shard]
out_shapes.append(out_shape)

in0_block_h = M // ttnn.TILE_SIZE
in0_block_w = K // num_cores // ttnn.TILE_SIZE
while (K / ttnn.TILE_SIZE) % in0_block_w != 0:
in0_block_w -= 1

out_block_h = M // ttnn.TILE_SIZE
out_block_w = N // num_cores // ttnn.TILE_SIZE
out_block_w = N_padded // num_cores // ttnn.TILE_SIZE

out_subblock_h = 1
out_subblock_w = max_dst_tiles
@@ -347,33 +367,34 @@ def run_prefetcher_mm(

output_mem_configs = []
for shape in out_shapes:
_, _, M, N = shape
_, _, M, N_per_shard = shape

output_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
output_core_range_set,
[M, N // num_cores],
[M, N_per_shard],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
output_mem_configs.append(output_sharded_mem_config)

in0_tensors = []
in0_t_tensors = []
for shape in in0_shapes:
for shape, shard_shape in zip(in0_shapes, shard_shapes):
in0 = torch.randn(shape)
in0_tensors.append(in0)

_, _, M, K = shape
_, _, M, _ = shape
K_per_shard, _ = shard_shape

in0_sharded_mem_config = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
input_core_range_set,
[M, K // num_cores],
[M, K_per_shard],
ttnn.ShardOrientation.ROW_MAJOR,
),
)
8 changes: 8 additions & 0 deletions tests/ttnn/unit_tests/operations/test_prefetcher.py
@@ -22,6 +22,14 @@
(2, 3, [(256, 1024), (256, 2048), (512, 256)], [ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.bfloat4_b], 5),
(2, 2, [(256, 1024), (128, 128)], [ttnn.bfloat4_b, ttnn.bfloat8_b], 5),
(2, 3, [(256, 1024), (128, 128), (1024, 256)], [ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.bfloat4_b], 5),
# Padding check
(
2,
3,
[(256 + 32, 512 + 224), (128, 128 + 64), (512 + 256, 224)],
[ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.bfloat4_b],
5,
),
],
)
@pytest.mark.parametrize(
11 changes: 9 additions & 2 deletions tests/ttnn/unit_tests/operations/test_prefetcher_TG.py
@@ -37,13 +37,20 @@
(12, 5, [(3840, 2304)] * 5, [ttnn.bfloat8_b] * 5, 5), # FF2
(12, 6, [(2304, 1536)] * 6, [ttnn.bfloat8_b] * 6, 5), # QKV
(12, 5, [(2304, 2304)] * 5, [ttnn.bfloat8_b] * 5, 5), # DO
(
12,
5,
[(2304, 1536), (1536, 2304), (2304, 3840), (2304, 3840), (3840, 2304)],
[ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.bfloat4_b, ttnn.bfloat4_b, ttnn.bfloat8_b],
5,
), # qkv + do + ff1 + ff3 + ff2
# Takes really long to set up
(
12,
5,
[(2304, 1536), (2304, 2304), (2304, 3840), (2304, 3840), (3840, 2304)],
[(2048, 1280), (1280, 2048), (2048, 3584), (2048, 3584), (3584, 2048)],
[ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.bfloat4_b, ttnn.bfloat4_b, ttnn.bfloat8_b],
80, # DRAM OOM issue?
80,
), # qkv + do + ff1 + ff3 + ff2
],
)
@@ -130,10 +130,7 @@ FORCE_INLINE void update_rd_ptr_to_ring_index(
}

void MAIN {
// Runtime args
uint32_t rt_args_idx = 0;
uint32_t ring_idx = get_arg_val<uint32_t>(rt_args_idx++);

// Compile time args
constexpr uint32_t in0_block_w = get_compile_time_arg_val(0); // inner block size in tiles
constexpr uint32_t in0_num_subblocks = get_compile_time_arg_val(1); // outer row block size (in inner row blocks)
constexpr uint32_t in0_block_num_tiles =
@@ -153,6 +150,13 @@
constexpr uint32_t batch = get_compile_time_arg_val(13); // batch dim
constexpr uint32_t out_block_num_tiles = get_compile_time_arg_val(14); // number of tiles in out_block
constexpr bool untilize_out = get_compile_time_arg_val(15); // untilize output
constexpr uint32_t ring_size = num_blocks;

// Runtime args
uint32_t rt_args_idx = 0;
uint32_t ring_idx = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t* unpadded_in0_shard_widths_in_tiles = (uint32_t*)get_arg_addr(rt_args_idx);
rt_args_idx += ring_size;

constexpr uint32_t out_block_w = out_subblock_w * in1_num_subblocks;

@@ -216,6 +220,9 @@
cb_pop_front(sync_cb2, 1);

for (uint32_t block = 0; block < num_blocks; block++) {
const uint32_t curr_ring_idx = (ring_idx + block) % ring_size;
uint32_t unpadded_in0_block_w = unpadded_in0_shard_widths_in_tiles[curr_ring_idx];

const uint32_t input0_cb_id = block == 0 ? in0_cb_id : in2_cb_id;
bool last_out = block == (num_blocks - 1);
// Configure packer once for pack out without Bias
@@ -251,7 +258,7 @@
#ifdef ENABLE_GLOBAL_CB
int in1_index_subblock_offset = 0;
#else
int in1_index_subblock_offset = in1_block_num_tiles * ((ring_idx + block) % num_blocks);
int in1_index_subblock_offset = in1_block_num_tiles * (curr_ring_idx);
#endif
for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) {
tile_regs_acquire();
@@ -273,7 +280,7 @@
uint32_t in0_index = in0_index_subblock_offset; // offset into in0 block
uint32_t in1_index = in1_index_subblock_offset; // offset into in1 block
// inner dim that we accumulate is the inner dim of in0/in1, which is in0_block_w
for (uint32_t inner_dim_idx = 0; inner_dim_idx < in0_block_w; ++inner_dim_idx) {
for (uint32_t inner_dim_idx = 0; inner_dim_idx < unpadded_in0_block_w; ++inner_dim_idx) {
// matmul outer product of (out_subblock_h x out_subblock_w) tiles that fill dst
// accumulation is done by iterating matmul_block across inner dim
// in0_block_w is passed as inner dim (kt) to matmul_block, internally used to stride in0
@@ -26,6 +26,8 @@ void kernel_main() {
uint32_t noc = get_arg_val<uint32_t>(rt_args_idx++);
bool is_hop_core = (bool)get_arg_val<uint32_t>(rt_args_idx++);
bool end_of_hop = (bool)get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t* unpadded_in0_shard_widths_in_tiles = (uint32_t*)get_arg_addr(rt_args_idx);
rt_args_idx += ring_size;

volatile tt_l1_ptr uint32_t* l1_signal_sem_addr =
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(signal_semaphore_addr);
@@ -48,6 +50,9 @@

for (uint32_t b = 0; b < batch; ++b) {
for (uint32_t shard_cnt = hop_core_offset; shard_cnt < ring_size; shard_cnt++) {
uint32_t curr_ring_idx = (ring_idx + shard_cnt) % ring_size;
bool skip_send = unpadded_in0_shard_widths_in_tiles[curr_ring_idx] == 0 && !is_hop_core;

uint32_t curr_shard_write_addr = l1_write_addr_in0 + shard_size_bytes * (shard_cnt - hop_core_offset);
uint64_t remote_curr_shard_write_addr =
get_noc_addr(next_core_noc_x, next_core_noc_y, curr_shard_write_addr, noc);
@@ -59,7 +64,9 @@

// Send data to next core
if (shard_cnt < ring_size - 1 || is_hop_core) { // Skip sending the last shard
noc_async_write(curr_shard_read_addr, remote_curr_shard_write_addr, shard_size_bytes, noc);
if (!skip_send) {
noc_async_write(curr_shard_read_addr, remote_curr_shard_write_addr, shard_size_bytes, noc);
}

// Signal the next core that data is ready
noc_semaphore_inc(remote_signal_semaphore_addr, 1, noc);
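
To make the padding-skip logic in the kernels above concrete, here is a hedged Python model of the behaviour. The host-side derivation of `unpadded_in0_shard_widths_in_tiles` shown here is an assumption for illustration (the actual runtime-arg setup lives in the program factory, which is not part of this excerpt); the skip condition mirrors the reader-kernel change above, and the K=2080 / 24-core example is hypothetical.

```python
import math

TILE_SIZE = 32  # ttnn.TILE_SIZE


def unpadded_in0_shard_widths(K, num_cores):
    """Assumed model of the per-ring-position widths (in tiles) passed as
    runtime args: every shard is padded to K_per_shard elements, and the
    unpadded Kt is consumed shard by shard, so trailing all-padding shards
    end up with a width of 0."""
    K_per_shard = TILE_SIZE * math.ceil(math.ceil(K / num_cores) / TILE_SIZE)
    shard_wt = K_per_shard // TILE_SIZE   # padded shard width in tiles
    Kt = math.ceil(K / TILE_SIZE)         # unpadded K in tiles
    return [min(shard_wt, max(Kt - i * shard_wt, 0)) for i in range(num_cores)]


def should_skip_send(widths, ring_idx, shard_cnt, is_hop_core):
    """Mirrors the reader-kernel condition: a non-hop core skips the NOC write
    for a shard whose unpadded width is 0 (i.e. the shard is pure padding)."""
    curr_ring_idx = (ring_idx + shard_cnt) % len(widths)
    return widths[curr_ring_idx] == 0 and not is_hop_core


# Hypothetical example: K = 2080 on a 24-core ring -> 3-tile shards, and the
# last two ring positions hold only padding.
widths = unpadded_in0_shard_widths(2080, 24)
print(widths.count(0))                          # 2
print(should_skip_send(widths, 0, 23, False))   # True
```

In this model, a shard whose unpadded width is 0 is pure padding, so the compute kernel accumulates zero inner-dim tiles for it and the reader never forwards its data around the ring.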
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1538,7 +1538,9 @@ void Matmul::validate(
TT_FATAL(M == per_core_M, "Error");
TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error");
TT_FATAL(K % program_config.in0_block_w == 0, "Error");
TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error");
if (!program_config.gather_in0) { // Padding allowed for gather_in0
TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error");
}
}
if (this->output_mem_config.is_sharded()) {
TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Error");