From 667a678c604b75efff1f50d8842fe71209fbd3cc Mon Sep 17 00:00:00 2001 From: Borys Bradel <164946524+bbradelTT@users.noreply.github.com> Date: Sat, 12 Oct 2024 09:54:58 -0400 Subject: [PATCH] #13080: add out subblock count validation for matmul (#13347) * #13080: add out subblock count validation for matmul half dst mode * #13080: add method to get dest reg count for compute kernel config * #13080: include tile size when calculating dest register count and move output tile info creation * #13080: add checks for matmul that tile width and height are greater than 0 --- .../compute_kernel/compute_kernel_config.cpp | 36 +++++++++++++++ .../compute_kernel/compute_kernel_config.hpp | 2 + .../operations/matmul/device/matmul_op.cpp | 46 +++++++++++++------ 3 files changed, 71 insertions(+), 13 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp index 32a95ef60be..c68707711fd 100644 --- a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp +++ b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp @@ -2,9 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "tt_metal/common/constants.hpp" #include "compute_kernel_config.hpp" #include "ttnn/device.hpp" +#define DATUMS_PER_ROW 16 + namespace ttnn { DeviceComputeKernelConfig init_device_compute_kernel_config( @@ -138,4 +141,37 @@ std::tuple get_compute_kernel_config_args( return std::make_tuple(math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en); } +uint32_t get_dest_reg_count(const DeviceComputeKernelConfig& compute_kernel_config, std::optional> tile_shape) { + + uint32_t tile_height; + uint32_t tile_width; + if (tile_shape.has_value()) { + std::array& shape = tile_shape.value(); + tile_height = shape[0]; + tile_width = shape[1]; + } else { + tile_height = tt::constants::TILE_HEIGHT; + tile_width = tt::constants::TILE_WIDTH; + } + // Note: if DATUMS_PER_ROW will change in a future architecture, then + // this code will need to be updated to use an architecture specific value. + uint32_t available_reg_count = (DEST_REGISTER_FULL_SIZE * DATUMS_PER_ROW) / (tile_width * tile_height); + std::visit( + [&](auto&& compute_kernel_config) { + using T = std::decay_t; + if (!compute_kernel_config.dst_full_sync_en) { + available_reg_count /= 2; + } + if constexpr (std::is_same_v) { + // Note: using bfloat16 as baseline to be conservative, even + // if smaller formats could have a larger register count. + if (compute_kernel_config.fp32_dest_acc_en) { + available_reg_count /= 2; + } + } + }, + compute_kernel_config); + return available_reg_count; +} + } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp index e12809a2f5c..2ec7a37f89a 100644 --- a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp +++ b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp @@ -44,4 +44,6 @@ MathFidelity get_math_fidelity(const std::optional& c std::tuple get_compute_kernel_config_args(tt::ARCH arch, const DeviceComputeKernelConfig compute_kernel_config); +uint32_t get_dest_reg_count(const DeviceComputeKernelConfig& compute_kernel_config, std::optional> tile_shape=std::nullopt); + } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index 0e20ec28aba..467f903bea6 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -811,6 +811,23 @@ inline MatmulProgramConfig get_program_config( return config; } +tt::tt_metal::Tile get_output_tile(const MemoryConfig& output_mem_config, const std::array& in0_tile_shape, const std::array& in1_tile_shape, const std::optional output_tile) { + if (output_tile.has_value()) { + const auto& out_tile_shape = output_tile->get_tile_shape(); + TT_FATAL(out_tile_shape[1] > 0, "the override output tile width needs to be greater than zero"); + TT_FATAL(out_tile_shape[1] % in1_tile_shape[1] == 0, "the override output tile width be multiple of in1 tile width"); + TT_FATAL(out_tile_shape[0] > 0, "the override output tile height needs to be greater than zero"); + TT_FATAL(out_tile_shape[0] == in0_tile_shape[0], "the override output tile height must equal to the in0 tile height"); + if (out_tile_shape[1] != in1_tile_shape[1]) { + TT_FATAL(out_tile_shape[0] <= constants::FACE_HEIGHT, "the override output tile height must equal or less to face height"); + } + if (!output_mem_config.is_sharded()) { + TT_FATAL(out_tile_shape[1] == in1_tile_shape[1], "the override output tile width must equal to the in0 tile width"); + } + } + return output_tile.value_or(tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]})); +} + } // namespace namespace bmm_op_utils { @@ -887,6 +904,10 @@ Matmul create_matmul_struct( bool broadcast_batch = parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config)); TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config"); + const auto& in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); + const auto& in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + tt::tt_metal::Tile output_tile = get_output_tile( + parameters.output_mem_config, in0_tile_shape, in1_tile_shape, parameters.output_tile); return Matmul{ parameters.program_config, @@ -900,7 +921,7 @@ Matmul create_matmul_struct( parameters.user_run_batched, parameters.transpose_a, parameters.transpose_b, - parameters.output_tile}; + output_tile}; } Tensor matmul( @@ -974,7 +995,8 @@ void Matmul::validate( a_shape[-1], b_shape[-2]); - TT_FATAL(this->bcast_batch.has_value(), "Error"); + TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been automatically populated"); + TT_FATAL(this->output_tile.has_value(), "Error: output_tile field should have been automatically populated"); if (this->bcast_batch.value()) { TT_FATAL( get_batch_size(b_shape) == 1, @@ -1280,6 +1302,14 @@ void Matmul::validate( TT_FATAL( program_config.per_core_N % program_config.out_subblock_w == 0, "per_core_N must be divisible by out_subblock_w"); + uint32_t available_reg_count = ttnn::get_dest_reg_count( + this->compute_kernel_config.value(), this->output_tile.value().get_tile_shape()); + TT_FATAL( + (program_config.out_subblock_w * program_config.out_subblock_h) <= available_reg_count, + "out_subblock_w {} times out_subblock_h {} needs to be at most {} to fit in hardware", + program_config.out_subblock_w, + program_config.out_subblock_h, + available_reg_count); } }, chosen_program_config); @@ -1315,17 +1345,7 @@ std::vector Matmul::create_output_tensors(const std::vector& inp const auto& input_tensor_b = input_tensors.at(1); auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); - if (this->output_tile.has_value()) { - TT_FATAL(this->output_tile->get_tile_shape()[1] % in1_tile_shape[1] == 0, "the override output tile width be multiple of in1 tile width"); - TT_FATAL(this->output_tile->get_tile_shape()[0] == in0_tile_shape[0], "the override output tile height must equal to the in0 tile height"); - if (this->output_tile->get_tile_shape()[1] != in1_tile_shape[1]) { - TT_FATAL(this->output_tile->get_tile_shape()[0] <= constants::FACE_HEIGHT, "the override output tile height must equal or less to face height"); - } - if (!this->output_mem_config.is_sharded()) { - TT_FATAL(this->output_tile->get_tile_shape()[1] == in1_tile_shape[1], "the override output tile width must equal to the in0 tile width"); - } - } - auto output_tile = this->output_tile.value_or(tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]})); + auto output_tile = this->output_tile.value(); auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1]; auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE; TT_FATAL(this->output_dtype.has_value(), "Error");