From 7d03e02a888879767aa51378420a3c1c63228d97 Mon Sep 17 00:00:00 2001
From: Brian Liu
Date: Fri, 5 Jul 2024 21:57:27 +0000
Subject: [PATCH] #9849: Move checks on batch dims for matmul to validate

- This adds these checks to matmul_multicore and matmul_multicore_reuse
  as an intended side effect
---
 tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp     | 21 ++++++++++++++++---
 ...op_multi_core_reuse_mcast_1d_optimized.cpp | 13 ------------
 ...op_multi_core_reuse_mcast_2d_optimized.cpp | 13 ------------
 ...ulti_core_reuse_dram_sharded_optimized.cpp |  6 ------
 .../bmm_op_multi_core_reuse_optimized.cpp     | 13 ------------
 5 files changed, 18 insertions(+), 48 deletions(-)

diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
index 9965283668f..ee05a7fd8f0 100644
--- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
@@ -750,8 +750,9 @@ void Matmul::validate(
     TT_FATAL(input_tensors.size() == 2);
     const auto& input_tensor_a = input_tensors.at(0);
     const auto& input_tensor_b = input_tensors.at(1);
-    auto a_shape = input_tensor_a.get_shape();
-    auto b_shape = input_tensor_b.get_shape();
+    const auto& a_shape = input_tensor_a.get_shape();
+    const auto& b_shape = input_tensor_b.get_shape();
+
     TT_FATAL(
         (input_tensor_a.get_layout() == Layout::TILE && input_tensor_b.get_layout() == Layout::TILE),
         "Inputs to matmul must be tilized");
@@ -761,6 +762,20 @@ void Matmul::validate(
         a_shape[-1],
         b_shape[-2]);
 
+    if (this->bcast_batch) {
+        TT_FATAL(
+            get_batch_size(b_shape) == 1 &&
+            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
+    } else {
+        // same condition as above, different message
+        TT_FATAL(a_shape.rank() == b_shape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
+        for (auto i = 0; i < a_shape.rank() - 2; i++) {
+            TT_FATAL(
+                a_shape[i] == b_shape[i] &&
+                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
+        }
+    }
+
     TT_FATAL(is_floating_point(input_tensor_a.get_dtype()), "Unsupported data format");
     TT_FATAL(
         input_tensor_a.storage_type() == StorageType::DEVICE and input_tensor_b.storage_type() == StorageType::DEVICE,
@@ -781,7 +796,7 @@ void Matmul::validate(
             uint32_t bias_batch_size = get_batch_size(bias_shape);
             TT_FATAL(bias_batch_size == 1, "Unsupported bias shape: batch size not equal to 1.");
             TT_FATAL(bias_shape[-2] == TILE_HEIGHT, "Unsupported bias shape: second last dimension not equal to tile height");
-            TT_FATAL(bias_shape[-1] == input_tensor_b.get_legacy_shape()[-1], "Unsupported bias shape: last dimension not equal to second input's last dimension.");
+            TT_FATAL(bias_shape[-1] == b_shape[-1], "Unsupported bias shape: last dimension not equal to second input's last dimension.");
         }
 
         if (this->untilize_out) {
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
index 9bf58816b61..1b700509f83 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
@@ -1509,19 +1509,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer* in0_buffer = a.buffer();
     tt_metal::Buffer* in1_buffer = b.buffer();
-    if (bcast_batch)
-        TT_FATAL(
-            get_batch_size(bshape) == 1 &&
-            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
-    else {
-        // same condition as above, different message
-        TT_FATAL(ashape.rank() == bshape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
-        for (auto i = 0; i < ashape.rank() - 2; i++) {
-            TT_FATAL(
-                ashape[i] == bshape[i] &&
-                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-        }
-    }
 
     TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
     TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
index 4f9623edca3..86cd8d442a6 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_2d_optimized/bmm_op_multi_core_reuse_mcast_2d_optimized.cpp
@@ -1232,19 +1232,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_(
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer* in0_buffer = a.buffer();
     tt_metal::Buffer* in1_buffer = b.buffer();
-    if (bcast_batch)
-        TT_FATAL(
-            get_batch_size(bshape) == 1 &&
-            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
-    else {
-        // same condition as above, different message
-        TT_FATAL(ashape.rank() == bshape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
-        for (auto i = 0; i < ashape.rank() - 2; i++) {
-            TT_FATAL(
-                ashape[i] == bshape[i] &&
-                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-        }
-    }
 
     TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
     TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
index 6b76abebbbf..2fd60cce636 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp
@@ -1152,12 +1152,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_(
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer* in0_buffer = a.buffer();
     tt_metal::Buffer* in1_buffer = b.buffer();
-    TT_FATAL(ashape.rank() == bshape.rank() && ashape.rank() >= 2 && "bmm (non-bcast matmul) expects input tensors of the same rank and must have rank >= 2");
-    for (auto i = 0; i < ashape.rank() - 2; i++) {
-        TT_FATAL(
-            ashape[i] == bshape[i] &&
-            "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-    }
 
     TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
     TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
index aa38b30cf64..8ccc6545453 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_optimized/bmm_op_multi_core_reuse_optimized.cpp
@@ -463,19 +463,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_optimized_(const Tensor
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer *in0_buffer = a.buffer();
     tt_metal::Buffer *in1_buffer = b.buffer();
-    if (bcast_batch)
-        TT_FATAL(
-            get_batch_size(bshape) == 1 &&
-            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
-    else {
-        // same condition as above, different message
-        TT_FATAL(ashape.rank() == bshape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
-        for (auto i = 0; i < ashape.rank() - 2; i++) {
-            TT_FATAL(
-                ashape[i] == bshape[i] &&
-                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-        }
-    }
 
     MathFidelity math_fidelity;
     bool math_approx_mode;
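
For context, the consolidated check behaves as sketched below. This is a minimal
standalone C++ illustration, not tt-metal code: Shape, batch_size, and
validate_batch_dims here are hypothetical stand-ins for tt-metal's Shape type,
get_batch_size helper, and the TT_FATAL assertions, assuming row-major
[..., M, K] x [..., K, N] inputs.

    // Standalone sketch of the batch-dim validation this patch centralizes
    // in Matmul::validate. All names below are illustrative stand-ins.
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <stdexcept>
    #include <vector>

    using Shape = std::vector<uint32_t>;

    // Product of all dims except the last two, i.e. the combined batch size B*C*...
    uint32_t batch_size(const Shape& shape) {
        if (shape.size() < 2) throw std::runtime_error("expected rank >= 2");
        return std::accumulate(shape.begin(), shape.end() - 2, 1u, std::multiplies<uint32_t>());
    }

    void validate_batch_dims(const Shape& a, const Shape& b, bool bcast_batch) {
        if (bcast_batch) {
            // BCMK * 11KN = BCMN: the second input must carry no batch dims.
            if (batch_size(b) != 1)
                throw std::runtime_error("matmul (batch bcast variant) expects BCMK*11KN=BCMN");
        } else {
            // BCMK * BCKN = BCMN: equal ranks and matching leading (batch) dims.
            if (a.size() != b.size())
                throw std::runtime_error("bmm expects input tensors of the same rank");
            for (std::size_t i = 0; i + 2 < a.size(); i++)
                if (a[i] != b[i])
                    throw std::runtime_error("bmm expects BCMK*BCKN=BCMN");
        }
    }

    int main() {
        validate_batch_dims({2, 3, 32, 64}, {1, 1, 64, 32}, /*bcast_batch=*/true);   // passes
        validate_batch_dims({2, 3, 32, 64}, {2, 3, 64, 32}, /*bcast_batch=*/false);  // passes
    }

Per the commit message, the behavioral consequence of hoisting these conditions
into validate() is that they now run for every matmul program config, including
matmul_multicore and matmul_multicore_reuse, which previously did not perform them.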