#0: bmm dram sharded api cleanup
yugaoTT committed Jun 6, 2024
1 parent 049c4bb commit a614375
Showing 6 changed files with 18 additions and 80 deletions.
@@ -146,11 +146,8 @@ def run_test_matmul_in1_dram_sharded(

program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
in0_block_w=in0_block_w // 4,
- out_subblock_h=out_subblock_h,
- out_subblock_w=out_subblock_w,
per_core_M=out_block_h,
per_core_N=out_block_w,
- fuse_batch=True,
fused_activation=None,
)

@@ -358,11 +355,8 @@ def run_test_matmul_in1_dram_sharded_mm_chain(

program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
in0_block_w=in0_block_w // 4,
- out_subblock_h=out_subblock_h,
- out_subblock_w=out_subblock_w,
per_core_M=out_block_h,
per_core_N=out_block_w,
- fuse_batch=True,
fused_activation=None,
)

3 changes: 0 additions & 3 deletions tests/ttnn/unit_tests/operations/test_experimental.py
@@ -217,11 +217,8 @@ def test_ttnn_experimental_operations_primary_matmul_dram_sharded(device, m_size

program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
in0_block_w=32,
- out_subblock_h=1,
- out_subblock_w=4,
per_core_M=1,
per_core_N=4,
- fuse_batch=True,
fused_activation=None,
)

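For reference, a minimal sketch (not from the commit itself) of how the trimmed-down program config is constructed after this cleanup, using the same module path and values that remain in test_ttnn_experimental_operations_primary_matmul_dram_sharded above; out_subblock_h, out_subblock_w and fuse_batch are no longer accepted keyword arguments:

```python
# Sketch only: the DRAM-sharded matmul program config after this cleanup.
# Subblock sizes and batch fusing are now handled internally by the op.
import ttnn

program_config = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
    in0_block_w=32,         # K tiles per inner block, as in the test above
    per_core_M=1,           # output tiles per core along M
    per_core_N=4,           # output tiles per core along N
    fused_activation=None,  # optional UnaryWithParam fused into the matmul
)
```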
16 changes: 2 additions & 14 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
@@ -1020,15 +1020,11 @@ void Matmul::validate(
MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) {
TT_FATAL(input_tensor_a.is_sharded());
TT_FATAL(this->output_mem_config.is_sharded());
- TT_FATAL(program_config.fuse_batch);
TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED);
TT_FATAL(input_tensor_a.memory_config().buffer_type == this->output_mem_config.buffer_type);
TT_FATAL(input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout);
TT_FATAL(input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR);
- uint32_t M =
-     (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1]
-                                : input_tensor_a.get_legacy_shape()[-2]) /
-     TILE_HEIGHT;
+ uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT;
uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH;
uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH;
uint32_t per_core_M = program_config.per_core_M;
@@ -1041,8 +1037,6 @@ void Matmul::validate(
TT_FATAL(K % program_config.in0_block_w == 0);
TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0);

- // subbblock constraint
- TT_FATAL(program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1);
// tensor in1
TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED);
} else if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCastProgramConfig>) {
@@ -1244,9 +1238,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
} else if constexpr (std::is_same_v<
ProgramConfigType,
MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) {
- uint32_t M =
-     (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1]
-                                : input_tensor_a.get_legacy_shape()[-2]) / TILE_HEIGHT;
+ uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1];
uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH;
auto input_tensor_b_shape = input_tensor_b.get_legacy_shape();

@@ -1437,14 +1429,10 @@ operation::ProgramWithCallbacks Matmul::create_program(
input_tensor_b,
bias,
output_tensor,
- broadcast_batch,
this->compute_kernel_config,
program_config.in0_block_w,
- program_config.out_subblock_h,
- program_config.out_subblock_w,
program_config.per_core_M,
program_config.per_core_N,
- program_config.fuse_batch,
program_config.fused_activation,
this->untilize_out,
false,
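A short illustration (not part of the commit) of what dropping fuse_batch from validate() means in practice: M is now always derived from the full tensor volume, so batch dimensions are unconditionally folded into M. The helper name and example shape below are hypothetical.

```python
# Hypothetical check of the new validate() formula:
#   M = volume / last_dim / TILE_HEIGHT
TILE_HEIGHT = 32

def m_in_tiles(shape):
    """Matmul height in tiles, with batch dims folded in (the old fuse_batch=True path)."""
    volume = 1
    for dim in shape:
        volume *= dim
    return volume // shape[-1] // TILE_HEIGHT

# A [1, 2, 64, 4096] activation gives M = (1*2*64*4096) / 4096 / 32 = 4 tiles,
# exactly what the removed fuse_batch=True branch produced; the non-fused branch is gone.
assert m_in_tiles([1, 2, 64, 4096]) == 4
```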
11 changes: 1 addition & 10 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp
@@ -27,7 +27,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_padding (const Tensor &i
operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_padding (const Tensor &input_tensor_a, const Tensor &input_tensor_b, Tensor& output_tensor, bool bcast_batch);

operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const std::optional<const Tensor> bias, Tensor &output_tensor, bool bcast_batch, CoreCoord compute_with_storage_grid_size, DeviceComputeKernelConfig compute_kernel_config, uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, std::optional<UnaryWithParam> fused_activation, bool mcast_in0, bool untilize_out);
- operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const std::optional<const Tensor> bias, Tensor &output_tensor, bool bcast_batch, DeviceComputeKernelConfig compute_kernel_config, uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, std::optional<UnaryWithParam> fused_activation, bool untilize_out, bool skip_compute, bool skip_in0_mcast, bool skip_write_back);
+ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const std::optional<const Tensor> bias, Tensor &output_tensor, DeviceComputeKernelConfig compute_kernel_config, uint32_t in0_block_w, uint32_t per_core_M, uint32_t per_core_N, std::optional<UnaryWithParam> fused_activation, bool untilize_out, bool skip_compute, bool skip_in0_mcast, bool skip_write_back);
operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized(const Tensor &input_tensor_a, const Tensor &input_tensor_b, const std::optional<const Tensor> bias, Tensor &output_tensor, bool bcast_batch, CoreCoord compute_with_storage_grid_size, DeviceComputeKernelConfig compute_kernel_config, uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, bool transpose_mcast, std::optional<UnaryWithParam> fused_activation, bool untilize_out);
operation::ProgramWithCallbacks bmm_multi_core_reuse_optimized(const Tensor& input_tensor_a, const Tensor& input_tensor_b, Tensor &output_tensor, bool bcast_batch, CoreCoord compute_with_storage_grid_size, tt::tt_metal::DataType output_dtype, DeviceComputeKernelConfig compute_kernel_config, uint32_t in0_block_w, uint32_t out_subblock_h, uint32_t out_subblock_w, uint32_t per_core_M, uint32_t per_core_N, bool fuse_batch, bool untilize_out);

@@ -220,29 +220,20 @@ struct MatmulMultiCoreReuseMultiCast1DProgramConfig {

struct MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig {
std::size_t in0_block_w;
- std::size_t out_subblock_h;
- std::size_t out_subblock_w;
std::size_t per_core_M;
std::size_t per_core_N;
- bool fuse_batch;
std::optional<UnaryWithParam> fused_activation;

static constexpr auto attribute_names = std::make_tuple(
"in0_block_w",
"out_subblock_h",
"out_subblock_w",
"per_core_M",
"per_core_N",
"fuse_batch",
"fused_activation");
const auto attribute_values() const {
return std::make_tuple(
std::cref(this->in0_block_w),
- std::cref(this->out_subblock_h),
- std::cref(this->out_subblock_w),
std::cref(this->per_core_M),
std::cref(this->per_core_N),
- std::cref(this->fuse_batch),
std::cref(this->fused_activation));
}
};
@@ -334,10 +334,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
uint32_t M,
uint32_t N,
uint32_t K,
- bool bcast_batch,
uint32_t in0_block_w,
- uint32_t out_subblock_h_storage,
- uint32_t out_subblock_w_storage,
uint32_t per_core_M,
uint32_t per_core_N_storage,
std::optional<UnaryWithParam> fused_activation,
@@ -1117,14 +1114,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_(
const Tensor& b,
const std::optional<const Tensor> bias,
Tensor& output,
- bool bcast_batch,
DeviceComputeKernelConfig compute_kernel_config,
uint32_t in0_block_w,
- uint32_t out_subblock_h,
- uint32_t out_subblock_w,
uint32_t per_core_M,
uint32_t per_core_N,
- bool fuse_batch,
std::optional<UnaryWithParam> fused_activation,
bool untilize_out,
bool skip_compute,
@@ -1159,26 +1152,22 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_(
uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
tt_metal::Buffer* in0_buffer = a.buffer();
tt_metal::Buffer* in1_buffer = b.buffer();
- if (bcast_batch)
+ TT_FATAL(ashape.rank() == bshape.rank() && ashape.rank() >= 2 && "bmm (non-bcast matmul) expects input tensors of the same rank and must have rank >= 2");
+ for (auto i = 0; i < ashape.rank() - 2; i++) {
TT_FATAL(
- bshape[0] * bshape[1] == 1 &&
- "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN");
- else {
- // same condition as above, different message
- TT_FATAL(
- ashape[1] == bshape[1] && ashape[0] == bshape[0] &&
- "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN");
+ ashape[i] == bshape[i] &&
+ "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
}
TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);

TT_FATAL(
- ashape[3] == bshape[2] &&
- "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in bmm_op"); // A.K == B.K
- TT_FATAL(ashape[2] % TILE_HEIGHT == 0);
- TT_FATAL(ashape[3] % TILE_WIDTH == 0);
- TT_FATAL(bshape[2] % TILE_HEIGHT == 0);
- TT_FATAL(bshape[3] % TILE_WIDTH == 0);
+ ashape[-1] == bshape[-2] &&
+ "Dimension K (A.shape[-1] and B.shape[-2]) must match for A and B in bmm_op"); // A.K == B.K
+ TT_FATAL(ashape[-2] % TILE_HEIGHT == 0);
+ TT_FATAL(ashape[-1] % TILE_WIDTH == 0);
+ TT_FATAL(bshape[-2] % TILE_HEIGHT == 0);
+ TT_FATAL(bshape[-1] % TILE_WIDTH == 0);

MathFidelity math_fidelity;
bool math_approx_mode;
@@ -1211,15 +1200,11 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_(
////////////////////////////////////////////////////////////////////////////
// NOTE: Pads matmul input dims to 512 x 512 multiples (ie. multiples of 16*32 x 16*32)
// NOTE: Maximum number of tiles in output is 120 * 16^2 = 30,720 (eg. [1, 1, 5120, 6144])
- uint32_t B = ashape[0] * ashape[1];
- uint32_t Mt = ashape[2] / TILE_HEIGHT;
- uint32_t Kt = ashape[3] / TILE_WIDTH;
- uint32_t Nt = bshape[3] / TILE_WIDTH;

- if (fuse_batch) {
- Mt = B * Mt;
- B = 1;
- }
+ uint32_t B = 1;
+ uint32_t Mt = get_batch_size(ashape) * ashape[-2] / TILE_HEIGHT;
+ uint32_t Kt = ashape[-1] / TILE_WIDTH;
+ uint32_t Nt = bshape[-1] / TILE_WIDTH;

TT_FATAL(Kt % in0_block_w == 0);

////////////////////////////////////////////////////////////////////////////
@@ -1242,10 +1227,7 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_(
Mt,
Nt,
Kt,
- bcast_batch,
in0_block_w,
- out_subblock_h,
- out_subblock_w,
per_core_M,
per_core_N,
fused_activation,
@@ -1268,14 +1250,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized(
const Tensor& b,
const std::optional<const Tensor> bias,
Tensor& output_tensor,
bool broadcast_batch,
DeviceComputeKernelConfig compute_kernel_config,
uint32_t in0_block_w,
- uint32_t out_subblock_h,
- uint32_t out_subblock_w,
uint32_t per_core_M,
uint32_t per_core_N,
- bool fuse_batch,
std::optional<UnaryWithParam> fused_activation,
bool untilize_out,
bool skip_compute,
@@ -1286,14 +1264,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized(
b,
bias,
output_tensor,
- broadcast_batch,
compute_kernel_config,
in0_block_w,
- out_subblock_h,
- out_subblock_w,
per_core_M,
per_core_N,
- fuse_batch,
fused_activation,
untilize_out,
skip_compute,
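To summarize the factory-side changes, here is a hedged Python sketch (not from the commit) of the rank-generic shape checks and tile counts that replace the bcast_batch/fuse_batch handling; batch_size below is a stand-in for the get_batch_size helper referenced in the C++ code, and the example shapes are illustrative:

```python
TILE_HEIGHT = 32
TILE_WIDTH = 32

def batch_size(shape):
    """Product of all dims except the last two (stand-in for get_batch_size)."""
    b = 1
    for dim in shape[:-2]:
        b *= dim
    return b

def check_and_tile_counts(ashape, bshape):
    # Same rank >= 2 and matching batch dims, mirroring the reworked TT_FATALs.
    assert len(ashape) == len(bshape) and len(ashape) >= 2
    for a_dim, b_dim in zip(ashape[:-2], bshape[:-2]):
        assert a_dim == b_dim, "batch dims of A and B must match"
    assert ashape[-1] == bshape[-2], "K (A.shape[-1], B.shape[-2]) must match"

    # B is always 1 now; the batch is folded into Mt.
    B = 1
    Mt = batch_size(ashape) * ashape[-2] // TILE_HEIGHT
    Kt = ashape[-1] // TILE_WIDTH
    Nt = bshape[-1] // TILE_WIDTH
    return B, Mt, Kt, Nt

# e.g. [1, 1, 32, 8192] x [1, 1, 8192, 1024] -> (1, 1, 256, 32)
print(check_and_tile_counts([1, 1, 32, 8192], [1, 1, 8192, 1024]))
```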
6 changes: 0 additions & 6 deletions tt_eager/tt_lib/csrc/operations/primary/module.hpp
@@ -126,17 +126,11 @@ void py_module(py::module& m_primary) {
std::size_t,
std::size_t,
std::size_t,
- std::size_t,
- std::size_t,
- bool,
std::optional<UnaryWithParam>>(),
py::kw_only(),
py::arg("in0_block_w").noconvert(),
py::arg("out_subblock_h").noconvert(),
py::arg("out_subblock_w").noconvert(),
py::arg("per_core_M").noconvert(),
py::arg("per_core_N").noconvert(),
py::arg("fuse_batch").noconvert(),
py::arg("fused_activation"))
.def_readwrite("fused_activation", &MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig::fused_activation)
.def("__repr__", [](const MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig& config) {
