#13080: add out subblock count validation for matmul (#13347)
* #13080: add out subblock count validation for matmul half dst mode

* #13080: add method to get dest reg count for compute kernel config

* #13080: include tile size when calculating dest register count and move output tile info creation

* #13080: add checks for matmul that tile width and height are greater than 0
bbradelTT authored Oct 12, 2024
1 parent d1f7385 commit 667a678
Showing 3 changed files with 71 additions and 13 deletions.
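In brief, this commit teaches Matmul::validate to reject out_subblock dimensions that cannot fit in the destination registers for the chosen compute kernel config and output tile shape. The sketch below shows how a caller might exercise the new check; it is a minimal illustration, assuming the config is default-constructible and that out_subblock_w and out_subblock_h come from a program config, neither of which is shown in this diff.

    ttnn::WormholeComputeKernelConfig config;
    config.fp32_dest_acc_en = true;   // on Wormhole this halves the usable dest register count
    config.dst_full_sync_en = false;  // half-dst mode also halves it
    uint32_t available = ttnn::get_dest_reg_count(config, std::array<uint32_t, 2>{32, 32});
    // Mirrors the TT_FATAL added in matmul_op.cpp below:
    TT_FATAL(
        out_subblock_w * out_subblock_h <= available,
        "out_subblock_w {} times out_subblock_h {} must be at most {} to fit in the destination registers",
        out_subblock_w,
        out_subblock_h,
        available);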
@@ -2,9 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/common/constants.hpp"
#include "compute_kernel_config.hpp"
#include "ttnn/device.hpp"

#define DATUMS_PER_ROW 16

namespace ttnn {

DeviceComputeKernelConfig init_device_compute_kernel_config(
@@ -138,4 +141,37 @@ std::tuple<MathFidelity, bool, bool, bool, bool> get_compute_kernel_config_args(
return std::make_tuple(math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en);
}

uint32_t get_dest_reg_count(const DeviceComputeKernelConfig& compute_kernel_config, std::optional<std::array<uint32_t, 2>> tile_shape) {

uint32_t tile_height;
uint32_t tile_width;
if (tile_shape.has_value()) {
std::array<uint32_t, 2>& shape = tile_shape.value();
tile_height = shape[0];
tile_width = shape[1];
} else {
tile_height = tt::constants::TILE_HEIGHT;
tile_width = tt::constants::TILE_WIDTH;
}
// Note: if DATUMS_PER_ROW changes in a future architecture, this code
// will need to be updated to use an architecture-specific value.
uint32_t available_reg_count = (DEST_REGISTER_FULL_SIZE * DATUMS_PER_ROW) / (tile_width * tile_height);
std::visit(
[&](auto&& compute_kernel_config) {
using T = std::decay_t<decltype(compute_kernel_config)>;
if (!compute_kernel_config.dst_full_sync_en) {
available_reg_count /= 2;
}
if constexpr (std::is_same_v<T, WormholeComputeKernelConfig>) {
// Note: using bfloat16 as baseline to be conservative, even
// if smaller formats could have a larger register count.
if (compute_kernel_config.fp32_dest_acc_en) {
available_reg_count /= 2;
}
}
},
compute_kernel_config);
return available_reg_count;
}

} // namespace ttnn
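For intuition, here is the arithmetic above with illustrative numbers; DEST_REGISTER_FULL_SIZE is architecture-defined, and the value used below is an assumption for the example rather than a figure taken from this diff.

    // Assume DEST_REGISTER_FULL_SIZE == 1024, with DATUMS_PER_ROW == 16 as defined above.
    // Default 32x32 tile: available = (1024 * 16) / (32 * 32) = 16 tiles.
    // Half-dst mode (dst_full_sync_en == false):  16 / 2 = 8 tiles.
    // Wormhole with fp32_dest_acc_en also set:     8 / 2 = 4 tiles.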
@@ -44,4 +44,6 @@ MathFidelity get_math_fidelity(const std::optional<DeviceComputeKernelConfig>& c

std::tuple<MathFidelity, bool, bool, bool, bool> get_compute_kernel_config_args(tt::ARCH arch, const DeviceComputeKernelConfig compute_kernel_config);

uint32_t get_dest_reg_count(const DeviceComputeKernelConfig& compute_kernel_config, std::optional<std::array<uint32_t, 2>> tile_shape=std::nullopt);

} // namespace ttnn
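A hypothetical call site for the declaration above, showing both the defaulted and an explicit tile shape; config stands in for a populated DeviceComputeKernelConfig.

    uint32_t regs_default = ttnn::get_dest_reg_count(config);  // defaults to TILE_HEIGHT x TILE_WIDTH
    uint32_t regs_tiny = ttnn::get_dest_reg_count(config, std::array<uint32_t, 2>{16, 32});  // tiny-tile shape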
46 changes: 33 additions & 13 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -811,6 +811,23 @@ inline MatmulProgramConfig get_program_config(
return config;
}

tt::tt_metal::Tile get_output_tile(const MemoryConfig& output_mem_config, const std::array<uint32_t, 2>& in0_tile_shape, const std::array<uint32_t, 2>& in1_tile_shape, const std::optional<const tt::tt_metal::Tile> output_tile) {
if (output_tile.has_value()) {
const auto& out_tile_shape = output_tile->get_tile_shape();
TT_FATAL(out_tile_shape[1] > 0, "the override output tile width needs to be greater than zero");
TT_FATAL(out_tile_shape[1] % in1_tile_shape[1] == 0, "the override output tile width must be a multiple of the in1 tile width");
TT_FATAL(out_tile_shape[0] > 0, "the override output tile height needs to be greater than zero");
TT_FATAL(out_tile_shape[0] == in0_tile_shape[0], "the override output tile height must equal the in0 tile height");
if (out_tile_shape[1] != in1_tile_shape[1]) {
TT_FATAL(out_tile_shape[0] <= constants::FACE_HEIGHT, "the override output tile height must be less than or equal to the face height");
}
if (!output_mem_config.is_sharded()) {
TT_FATAL(out_tile_shape[1] == in1_tile_shape[1], "the override output tile width must equal the in1 tile width when the output is not sharded");
}
}
return output_tile.value_or(tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]}));
}
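// The override rules above, restated as hypothetical examples. Tile shapes are
// {height, width}; sharded_cfg and interleaved_cfg are assumed MemoryConfig
// values, and FACE_HEIGHT is assumed to be 16.
//   get_output_tile(sharded_cfg, {16, 32}, {32, 32}, Tile({16, 64}))
//     -> OK: width is a multiple of the in1 width; height matches in0 and fits in a face.
//   get_output_tile(sharded_cfg, {32, 32}, {32, 32}, Tile({32, 64}))
//     -> rejected: a widened output tile needs a height of at most FACE_HEIGHT.
//   get_output_tile(interleaved_cfg, {16, 32}, {32, 32}, Tile({16, 64}))
//     -> rejected: the width must equal the in1 width when the output is not sharded.
//   get_output_tile(any_cfg, {16, 32}, {32, 32}, std::nullopt)
//     -> defaults to Tile({16, 32}).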

} // namespace

namespace bmm_op_utils {
@@ -887,6 +904,10 @@ Matmul create_matmul_struct(
bool broadcast_batch =
parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config));
TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config");
const auto& in0_tile_shape = input_tensor_a.get_tile().get_tile_shape();
const auto& in1_tile_shape = input_tensor_b.get_tile().get_tile_shape();
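// Resolve the output tile up front, validating any user-supplied override, so
// validate() and create_output_tensors() can rely on output_tile being populated.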
tt::tt_metal::Tile output_tile = get_output_tile(
parameters.output_mem_config, in0_tile_shape, in1_tile_shape, parameters.output_tile);

return Matmul{
parameters.program_config,
@@ -900,7 +921,7 @@
parameters.user_run_batched,
parameters.transpose_a,
parameters.transpose_b,
parameters.output_tile};
output_tile};
}

Tensor matmul(
@@ -974,7 +995,8 @@ void Matmul::validate(
a_shape[-1],
b_shape[-2]);

TT_FATAL(this->bcast_batch.has_value(), "Error");
TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated");
TT_FATAL(this->output_tile.has_value(), "Error: output_tile field should have been automatically populated");
if (this->bcast_batch.value()) {
TT_FATAL(
get_batch_size(b_shape) == 1,
@@ -1280,6 +1302,14 @@ void Matmul::validate(
TT_FATAL(
program_config.per_core_N % program_config.out_subblock_w == 0,
"per_core_N must be divisible by out_subblock_w");
uint32_t available_reg_count = ttnn::get_dest_reg_count(
this->compute_kernel_config.value(), this->output_tile.value().get_tile_shape());
TT_FATAL(
(program_config.out_subblock_w * program_config.out_subblock_h) <= available_reg_count,
"out_subblock_w {} times out_subblock_h {} needs to be at most {} to fit in hardware",
program_config.out_subblock_w,
program_config.out_subblock_h,
available_reg_count);
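// Illustration with assumed numbers: if get_dest_reg_count returns 8 for this
// config and output tile, a 4x2 out subblock (8 tiles) passes, while a 4x4
// subblock (16 tiles) trips the check above even though it would fit with
// dst_full_sync_en enabled.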
}
},
chosen_program_config);
@@ -1315,17 +1345,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
const auto& input_tensor_b = input_tensors.at(1);
auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape();
auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape();
if (this->output_tile.has_value()) {
TT_FATAL(this->output_tile->get_tile_shape()[1] % in1_tile_shape[1] == 0, "the override output tile width must be a multiple of the in1 tile width");
TT_FATAL(this->output_tile->get_tile_shape()[0] == in0_tile_shape[0], "the override output tile height must equal the in0 tile height");
if (this->output_tile->get_tile_shape()[1] != in1_tile_shape[1]) {
TT_FATAL(this->output_tile->get_tile_shape()[0] <= constants::FACE_HEIGHT, "the override output tile height must be less than or equal to the face height");
}
if (!this->output_mem_config.is_sharded()) {
TT_FATAL(this->output_tile->get_tile_shape()[1] == in1_tile_shape[1], "the override output tile width must equal the in1 tile width when the output is not sharded");
}
}
auto output_tile = this->output_tile.value_or(tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]}));
auto output_tile = this->output_tile.value();
auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1];
auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE;
TT_FATAL(this->output_dtype.has_value(), "Error");
