#16487: Optimize upsample for bilinear mode
For the upsample input configurations in test_segformer_decode_head.py, the perf improvement is above 50%.

Signed-off-by: Nilaykumar Patel <nkpatel@tenstorrent.com>
nkpatel-tt authored Jan 17, 2025
1 parent 208eefc commit 58fb827
Showing 4 changed files with 91 additions and 52 deletions.
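At a high level, the commit splits each output row between the two dataflow RISCs: the reader kernel produces the even output sticks and a second (writer) instance of the same kernel produces the odd ones. Each producer gets its own intermediate and scalar circular buffers, all double-buffered (buffering_factor = 2), and the compute kernel consumes the two CB pairs alternately.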
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_upsample.py
@@ -288,7 +288,7 @@ def test_bilinear_multi_core(
     max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
     nshards = max_nshards
     while nshards > 0:
-        if batch_size * height * width % (nshards * TILE_WIDTH) == 0:
+        if batch_size * height % (nshards) == 0:
             break
         nshards -= 1
     ncores = nshards
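The new check drops the width and TILE_WIDTH factors, so a shard count only needs to divide batch_size * height evenly. A minimal standalone sketch of the same descending search, with assumed shapes (the real test derives these from its pytest parametrization):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed illustrative values, not from the test suite.
    const uint32_t batch_size = 2, height = 48, width = 96;
    const uint32_t max_grid_x = 8, max_grid_y = 8;
    uint32_t nshards = std::min(batch_size * height * width, max_grid_x * max_grid_y);
    while (nshards > 0) {
        if ((batch_size * height) % nshards == 0) {  // new, width-independent condition
            break;
        }
        --nshards;
    }
    std::printf("ncores = %u\n", nshards);  // 48: 2 * 48 = 96 divides evenly by 48
    return 0;
}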
ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp
@@ -8,13 +8,9 @@
 #include "compute_kernel_api/reduce.h"
 #include "compute_kernel_api/pack_untilize.h"
 
-template <uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t unpA_face_r_dim>
+template <uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t unpA_face_r_dim>
 inline void reduce_h_fused(
-    const uint32_t in_cb_id,
-    const uint32_t in_scalar_cb_id,
-    const uint32_t in_ntiles_hwc,
-    const uint32_t in_stick_index,
-    const uint32_t out_cb_id) {
+    const uint32_t in_cb_id, const uint32_t in_scalar_cb_id, const uint32_t in_ntiles_hwc, const uint32_t out_cb_id) {
     cb_reserve_back(out_cb_id, 1);
     tile_regs_acquire();
     cb_wait_front(in_cb_id, 4);
@@ -40,30 +36,30 @@ inline void reduce_h_fused(
 
 namespace NAMESPACE {
 void MAIN {
-    constexpr uint32_t out_cb_id = tt::CBIndex::c_16;
-    constexpr uint32_t in1_cb_id = tt::CBIndex::c_1;
-    constexpr uint32_t bias_cb_id = tt::CBIndex::c_2;
-    constexpr uint32_t in_scalar_cb_id = tt::CBIndex::c_4;
-    constexpr uint32_t in2_cb_id = tt::CBIndex::c_24;
+    constexpr uint32_t in_cb_id1 = get_compile_time_arg_val(0);
+    constexpr uint32_t in_cb_id2 = get_compile_time_arg_val(1);
+    constexpr uint32_t in_scalar_cb_id1 = get_compile_time_arg_val(2);
+    constexpr uint32_t in_scalar_cb_id2 = get_compile_time_arg_val(3);
+    constexpr uint32_t out_cb_id = get_compile_time_arg_val(4);
 
-    constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
-    constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
-    constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
-    constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
-    constexpr uint32_t out_h = get_compile_time_arg_val(4);
-    constexpr uint32_t out_w = get_compile_time_arg_val(5);
-    constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);
+    constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(5);
+    constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(6);
+    constexpr uint32_t window_size_hw = get_compile_time_arg_val(7);
+    constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(8);
+    constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(9);
 
-    constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(8);
     constexpr uint32_t num_output_tiles = out_ntiles_c;  //* nblocks;
 
-    tilizeA_B_reduce_init<false, true>(in1_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, 2, 4);
+    tilizeA_B_reduce_init<false, true>(in_cb_id1, in_scalar_cb_id1, in_ntiles_hwc, out_cb_id, 2, 4);
     pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, 1, 2); /* pack 1 row (1x16 or 1x32) */
     for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; i++) {
-        cb_wait_front(in_scalar_cb_id, 1);
-        reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, window_size_hw>(
-            in1_cb_id, in_scalar_cb_id, in_ntiles_hwc, i, out_cb_id);
-        cb_pop_front(in_scalar_cb_id, 1);
+        const uint32_t cb_id = (i % 2 == 0) ? in_cb_id1 : in_cb_id2;
+        const uint32_t scalar_cb_id = (i % 2 == 0) ? in_scalar_cb_id1 : in_scalar_cb_id2;
+
+        // Wait for the core to push data in cb
+        cb_wait_front(scalar_cb_id, 1);
+        reduce_h_fused<in_ntiles_c, out_ntiles_c, window_size_hw>(cb_id, scalar_cb_id, in_ntiles_hwc, out_cb_id);
+        cb_pop_front(scalar_cb_id, 1);
     }
 }  // MAIN
 }  // namespace NAMESPACE
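A minimal host-side sketch (plain C++, not device code) of the ping-pong pattern MAIN now uses: even iterations consume the CB pair filled by the reader RISC, odd iterations the pair filled by the writer RISC, so both producers stay busy while compute drains them alternately. CB indices here are stand-ins for the real compile-time values.

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t in_cb_id1 = 1, in_cb_id2 = 2;  // stand-ins for the real CB indices
    const uint32_t nsticks = 8;                   // assumed per-core stick count
    for (uint32_t i = 0; i < nsticks; ++i) {
        const uint32_t cb_id = (i % 2 == 0) ? in_cb_id1 : in_cb_id2;
        std::printf("stick %u -> CB %u (filled by %s RISC)\n", i, cb_id,
                    (i % 2 == 0) ? "reader" : "writer");
    }
    return 0;
}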
ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp
@@ -32,29 +32,40 @@ void kernel_main() {
     uint32_t src1_addr = get_arg_val<uint32_t>(6);
     uint32_t read_offset = get_arg_val<uint32_t>(8);
     uint32_t is_last_row = get_arg_val<uint32_t>(9);
     uint32_t in_h = 1;
     constexpr bool src1_is_dram = false;
 
     constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
-    constexpr uint32_t out_cb_id = tt::CBIndex::c_1;
-    // constexpr uint32_t is_reader = get_compile_time_arg_val(2);
+    constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
+    constexpr uint32_t in_scalar_cb_id = get_compile_time_arg_val(2);
     constexpr uint32_t scale_h_inv_comp = get_compile_time_arg_val(3);
     constexpr uint32_t scale_w_inv_comp = get_compile_time_arg_val(4);
     constexpr uint32_t y_index_comp = get_compile_time_arg_val(5);
     constexpr uint32_t x_index_compute_comp = get_compile_time_arg_val(6);
+    constexpr uint32_t is_reader = get_compile_time_arg_val(7);
 
     uint32_t l1_read_addr = get_read_ptr(in_cb_id);
-    constexpr uint32_t in_scalar_cb_id = tt::CBIndex::c_4;
 
+    uint32_t total_nsticks_to_process = in_w * scale_w;
+    // Calculate the number of sticks to process per core by dividing the total number of sticks (in width direction)
+    // by 2.
+    uint32_t nsticks_to_process_on_core =
+        (total_nsticks_to_process % 2) ? total_nsticks_to_process / 2 + 1 : total_nsticks_to_process / 2;
     // assuming shard begins with a new row. TODO: generalize?
     float scale_h_inv = uint32_to_float(scale_h_inv_comp);
     float scale_w_inv = uint32_to_float(scale_w_inv_comp);
     float x, y, x_index, y_index, dx, dy;
     y_index = uint32_to_float(y_index_comp);
     float x_index_compute = uint32_to_float(x_index_compute_comp);
 
+    // If the current core is a writer core, adjust the x_index_compute to start from the correct position.
+    if (!is_reader) {
+        x_index_compute += scale_w_inv;
+        // If the total number of sticks is odd, process one less stick.
+        nsticks_to_process_on_core =
+            (total_nsticks_to_process % 2) ? nsticks_to_process_on_core - 1 : nsticks_to_process_on_core;
+    }
     for (uint32_t image_row = 0; image_row < in_image_rows_per_core * scale_h; ++image_row) {
         x_index = x_index_compute;
-        for (uint32_t j = 0; j < in_w * scale_w; j++) {
+        for (uint32_t j = 0; j < nsticks_to_process_on_core; j++) {
             cb_reserve_back(out_cb_id, 4);
             cb_reserve_back(in_scalar_cb_id, 1);
 
@@ -107,7 +118,7 @@ void kernel_main() {
             noc_async_read_barrier();
             cb_push_back(out_cb_id, 4);
             cb_push_back(in_scalar_cb_id, 1);
-            x_index += scale_w_inv;
+            x_index += scale_w_inv * 2;
         }
         y_index += scale_h_inv;
     }
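Taken together, the reader instance now produces the even output sticks and the writer instance the odd ones, each stepping x by 2 * scale_w_inv. A host-side C++ sketch of how the two cores tile one output row, using small assumed shapes rather than real kernel state:

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed values: 4 input sticks upscaled 2x in width -> 8 output sticks.
    const uint32_t in_w = 4, scale_w = 2;
    const float scale_w_inv = 1.0f / scale_w;
    const uint32_t total = in_w * scale_w;
    const uint32_t half_up = (total % 2) ? total / 2 + 1 : total / 2;

    for (int is_reader = 1; is_reader >= 0; --is_reader) {
        float x_index = 0.0f;            // x_index_compute for this row
        uint32_t nsticks = half_up;
        if (!is_reader) {
            x_index += scale_w_inv;      // writer starts one stick later
            if (total % 2) nsticks -= 1; // odd total: writer does one stick less
        }
        for (uint32_t j = 0; j < nsticks; ++j) {
            std::printf("%s: output stick at x=%.2f\n", is_reader ? "reader" : "writer", x_index);
            x_index += scale_w_inv * 2;  // skip the stick the other core handles
        }
    }
    return 0;
}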
@@ -138,7 +138,7 @@ operation::ProgramWithCallbacks bilinear_multi_core(
     auto halo_shard_shape = halo_in.shard_spec().value().shape;
 
     // CBs
-    uint32_t buffering_factor = 1;  // data is already fully buffered in the CBs since its sharded
+    uint32_t buffering_factor = 2;
 
     // input data is in a sharded CB
     uint32_t in_cb_id = CBIndex::c_0;
@@ -152,24 +152,39 @@
     auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);
 
     // intermediate tensor CB
-    uint32_t in1_cb_id = CBIndex::c_1;
+    uint32_t in_cb_id1 = CBIndex::c_1;
     CircularBufferConfig cb_src1_config =
         CircularBufferConfig(
-            4 * in_cb_pagesize,  // since 4 pixels per page are needed for intermediate tensor.
-            {{in1_cb_id, input_cb_data_format}})
-            .set_page_size(in1_cb_id, in_cb_pagesize);
+            4 * in_cb_pagesize * buffering_factor,  // since 4 pixels per page are needed for intermediate tensor.
+            {{in_cb_id1, input_cb_data_format}})
+            .set_page_size(in_cb_id1, in_cb_pagesize);
     auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config);
 
+    // intermediate tensor CB
+    uint32_t in_cb_id2 = CBIndex::c_2;
+    CircularBufferConfig cb_src2_config =
+        CircularBufferConfig(
+            4 * in_cb_pagesize * buffering_factor,  // since 4 pixels per page are needed for intermediate tensor.
+            {{in_cb_id2, input_cb_data_format}})
+            .set_page_size(in_cb_id2, in_cb_pagesize);
+    auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config);
 
     // scaler CB
-    uint32_t in_scalar_cb_id = CBIndex::c_4;
     uint32_t in_scalar_cb_pagesize = tile_size(input_cb_data_format);
-    uint32_t in_scalar_cb_npages = 1;
-    CircularBufferConfig in_scalar_cb_config =
-        CircularBufferConfig(in_scalar_cb_npages * in_scalar_cb_pagesize, {{in_scalar_cb_id, input_cb_data_format}})
-            .set_page_size(in_scalar_cb_id, in_scalar_cb_pagesize);
+    uint32_t in_scalar_cb_npages = 1 * buffering_factor;
+    uint32_t in_scalar_cb_id1 = CBIndex::c_4;
+    CircularBufferConfig in_scalar_cb_config1 =
+        CircularBufferConfig(in_scalar_cb_npages * in_scalar_cb_pagesize, {{in_scalar_cb_id1, input_cb_data_format}})
+            .set_page_size(in_scalar_cb_id1, in_scalar_cb_pagesize);
 
-    auto in_scalar_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_scalar_cb_config);
+    auto in_scalar_cb1 = tt_metal::CreateCircularBuffer(program, all_cores, in_scalar_cb_config1);
+
+    uint32_t in_scalar_cb_id2 = CBIndex::c_5;
+    CircularBufferConfig in_scalar_cb_config2 =
+        CircularBufferConfig(in_scalar_cb_npages * in_scalar_cb_pagesize, {{in_scalar_cb_id2, input_cb_data_format}})
+            .set_page_size(in_scalar_cb_id2, in_scalar_cb_pagesize);
+
+    auto in_scalar_cb2 = tt_metal::CreateCircularBuffer(program, all_cores, in_scalar_cb_config2);
     // output sharded CB with upsampled data
     uint32_t out_cb_id = CBIndex::c_16;
     uint32_t aligned_output_stick_nbytes = round_up_to_mul32(output_stick_nbytes);
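With buffering_factor = 2, each intermediate and scalar CB is double-buffered: the producing RISC can fill one set of pages while the compute kernel drains the other. A quick sketch of the resulting sizes under an assumed page size (the real in_cb_pagesize is derived from the shard spec):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t in_cb_pagesize = 2048;  // assumed; not the factory's actual value
    const uint32_t buffering_factor = 2;   // from this commit

    // Each intermediate CB holds the 4 neighbouring pixel sticks a bilinear
    // sample needs, times the buffering factor.
    const uint32_t intermediate_cb_bytes = 4 * in_cb_pagesize * buffering_factor;
    std::printf("intermediate CB: %u bytes (%u pages)\n",
                intermediate_cb_bytes, 4 * buffering_factor);
    return 0;
}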
@@ -205,36 +220,52 @@
 
     std::vector<uint32_t> reader_compile_time_args = {
         in_cb_id,
-        out_cb_id,
-        false,
+        in_cb_id1,
+        in_scalar_cb_id1,
         scale_h_inv_u32,
         scale_w_inv_u32,
         y_index_u32,
         x_index_compute_u32,
+        1,
     };
 
+    std::vector<uint32_t> writer_compile_time_args = {
+        in_cb_id,
+        in_cb_id2,
+        in_scalar_cb_id2,
+        scale_h_inv_u32,
+        scale_w_inv_u32,
+        y_index_u32,
+        x_index_compute_u32,
+        0,
+    };
+
     string writer_kernel_fname, reader_kernel_fname, compute_kernel_fname;
 
     reader_kernel_fname = std::string(
         "ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp");
+    writer_kernel_fname = std::string(
+        "ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp");
     compute_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp");
 
     uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / constants::TILE_WIDTH);
     std::vector<uint32_t> compute_compile_time_args = {
-        1,
+        in_cb_id1,
+        in_cb_id2,
+        in_scalar_cb_id1,
+        in_scalar_cb_id2,
+        out_cb_id,
         in_ntiles_c,
         1 * in_ntiles_c,
         4,
-        output_shape[1],
-        output_shape[2],
-        (uint32_t)std::ceil((float)output_shape[2] / constants::TILE_HEIGHT),
-        scale_factor_h * scale_factor_w,
         (uint32_t)std::ceil((float)output_shape[3] / constants::TILE_WIDTH),
         output_nsticks_per_core,  // loop count with blocks
-        input_shape[3],
     };
 
     auto reader_kernel =
         CreateKernel(program, reader_kernel_fname, all_cores, ReaderDataMovementConfig(reader_compile_time_args));
+    auto writer_kernel =
+        CreateKernel(program, writer_kernel_fname, all_cores, WriterDataMovementConfig(writer_compile_time_args));
     TT_FATAL(fp32_dest_acc_en == false, "fp32_dest_acc_en as true not supported. #12787 issue raised");
     auto reduce_op = ReduceOpMath::SUM;
     auto reduce_dim = ReduceOpDim::H;
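For reference, here is how the compute_compile_time_args vector above lands in the get_compile_time_arg_val indices of bilinear.cpp, as read from the kernel diff earlier in this commit:

// Host-side vector index -> kernel-side constant (from bilinear.cpp above):
//   [0] in_cb_id1           [1] in_cb_id2
//   [2] in_scalar_cb_id1    [3] in_scalar_cb_id2
//   [4] out_cb_id           [5] in_ntiles_c
//   [6] in_ntiles_hwc       [7] window_size_hw (4)
//   [8] out_ntiles_c        [9] nsticks_per_core_by_nblocks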
@@ -267,13 +298,14 @@
             reader_rt_args[8] = (core == 0) ? 1 : 0;
             reader_rt_args[9] = (core == ncores_nhw - 1) ? 1 : 0;
             SetRuntimeArgs(program, reader_kernel, core_coord, reader_rt_args);
+            SetRuntimeArgs(program, writer_kernel, core_coord, reader_rt_args);
             start_input_stick_id += input_nsticks_per_core;
         }
     } else {
         TT_FATAL(false, "Unsupported memory layout");
     }
 
-    auto override_runtime_args_callback = [reader_kernel, cb_src0, out_cb](
+    auto override_runtime_args_callback = [reader_kernel, writer_kernel, cb_src0, out_cb](
                                               const void* operation,
                                               Program& program,
                                               const std::vector<Tensor>& input_tensors,
