Skip to content

Commit

Permalink
#9142: BH -> Fix pack api
Browse files Browse the repository at this point in the history
  • Loading branch information
rtawfik01 committed Jun 6, 2024
1 parent 19182ad commit 29fc7f3
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 61 deletions.
16 changes: 16 additions & 0 deletions tests/tt_metal/test_utils/stimulus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ std::vector<ValueType> generate_strided_vector(
return results;
}

template <typename ValueType>
std::vector<ValueType> generate_constant_vector(
const ValueType& constant, const size_t& numel) {
std::vector<ValueType> results(numel);
for (unsigned int index = 0; index < numel; index+=1) {
results.at(index) = constant;
}
return results;
}

template <typename ValueType>
std::vector<ValueType> generate_uniform_random_vector(
ValueType min, ValueType max, const size_t numel, const float seed = 0) {
Expand Down Expand Up @@ -103,5 +113,11 @@ std::vector<PackType> generate_packed_strided_vector(
return pack_vector<PackType, ValueType>(generate_strided_vector(init, assigned, stride, offset, numel));
}

template <typename PackType, typename ValueType>
std::vector<PackType> generate_packed_constant_vector(
const ValueType& constant, const size_t& numel) {
return pack_vector<PackType, ValueType>(generate_constant_vector(constant, numel));
}

} // namespace test_utils
} // namespace tt
119 changes: 59 additions & 60 deletions tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,29 @@

template <bool untilize = false, bool zero_output = false>
inline void llk_pack_mop_config(const uint32_t output) {

const std::uint32_t output_id = get_output_id(output);
const std::uint32_t num_faces = get_output_num_faces(output_id);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const std::uint32_t tile_c_dim = get_output_tile_c_dim(output_id);
const bool partial_face = get_output_partial_face(output_id) && IS_BFP_FORMAT((uint)pack_dst_format[output_id]);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_mop_config_<untilize, zero_output, DstTileFaceLayout::RowMajor, false>(
pack_dst_format[output_id], face_r_dim, num_faces, partial_face, narrow_tile);
pack_dst_format[output_id],
face_r_dim,
tile_c_dim,
num_faces,
partial_face,
narrow_tile
);
}

template <bool untilize = false, bool is_fp32_dest_acc_en = false>
inline void llk_pack_hw_configure(const llk_pack_params_t *pack_params) {
const std::uint32_t output_id = get_output_id(pack_params->pack_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const std::uint32_t tile_c_dim = get_output_tile_c_dim(output_id);
const std::uint32_t num_faces = get_output_num_faces(output_id);
const bool partial_face = get_output_partial_face(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);
Expand All @@ -48,10 +57,12 @@ inline void llk_pack_hw_configure(const llk_pack_params_t *pack_params) {
pack_dst_format[output_id],
tile_size,
face_r_dim,
tile_c_dim,
num_faces,
partial_face,
narrow_tile,
pack_params->relu_config.val);
pack_params->relu_config.val
);
}

template <
Expand All @@ -66,14 +77,15 @@ inline void llk_pack_hw_configure_disaggregated(std::uint32_t pack_output) {
.f = {
.ApplyRelu = (std::uint32_t)relu_type,
.Threshold = relu_threshold,
}}};
}}};
llk_pack_hw_configure<untilize, is_fp32_dest_acc_en>(&llk_pack_params);
}

template <bool untilize = false, PoolType type, ReduceDim dim, bool is_fp32_dest_acc_en = false>
inline void llk_pack_reduce_hw_configure(const llk_pack_params_t *pack_params) {
const std::uint32_t output_id = get_output_id(pack_params->pack_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const std::uint32_t tile_c_dim = get_output_tile_c_dim(output_id);
const std::uint32_t num_faces = get_output_num_faces(output_id);
const bool partial_face = get_output_partial_face(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);
Expand All @@ -85,10 +97,12 @@ inline void llk_pack_reduce_hw_configure(const llk_pack_params_t *pack_params) {
pack_dst_format[output_id],
tile_size,
face_r_dim,
tile_c_dim,
num_faces,
partial_face,
narrow_tile,
pack_params->relu_config.val);
pack_params->relu_config.val
);
}

template <
Expand All @@ -107,14 +121,22 @@ inline void llk_pack_reduce_hw_configure_disaggregated(std::uint32_t pack_output

template <bool untilize = false, bool zero_output = false>
inline void llk_pack_init(const std::uint32_t pack_output = 16) {

const std::uint32_t output_id = get_output_id(pack_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const std::uint32_t tile_c_dim = get_output_tile_c_dim(output_id);
const std::uint32_t num_faces = get_output_num_faces(output_id);
const bool partial_face = get_output_partial_face(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_init_<untilize, zero_output, DstTileFaceLayout::RowMajor, false>(
pack_dst_format[output_id], face_r_dim, num_faces, partial_face, narrow_tile);
pack_dst_format[output_id],
face_r_dim,
tile_c_dim,
num_faces,
partial_face,
narrow_tile
);

// To untilize narrow tile (32x16) we just pack 2 faces back to back
// Number of datums to pack per row
Expand All @@ -129,26 +151,9 @@ inline std::uint32_t get_output_tile_address(std::uint8_t output_id, std::uint32
std::uint32_t pack_tile_addr;
if constexpr (out_of_order_output) {
pack_tile_addr = cb_interface[output_id].fifo_wr_ptr +
(std::uint32_t)(cb_interface[output_id].fifo_page_size) * output_tile_index - 1;
(std::uint32_t)(cb_interface[output_id].fifo_page_size)*output_tile_index - 1;
} else {
if constexpr (untilize) {
// FIXME: Need to support pack-untilize?
// std::uint16_t out_tile_index =
// (cb_interface[output_id].ublock_tile_cnt/cb_interface[output_id].ublock_ct)*cb_interface[output_id].row_tile_dim
// +
// cb_interface[output_id].ublock_tile_cnt%cb_interface[output_id].ublock_ct;
// //FIXME: optimize perf
// pack_tile_addr = cb_interface[output_id].fifo_wr_ptr + cb_interface[output_id].fifo_wr_tile_ptr - 1;
// pack_tile_addr += out_tile_index*(std::uint32_t)(cb_interface[output_id].fifo_page_size);

// cb_interface[output_id].ublock_tile_cnt++;

// if (cb_interface[output_id].ublock_tile_cnt == cb_interface[output_id].ublock_tile_dim) {
// cb_interface[output_id].ublock_tile_cnt=0;
// cb_interface[output_id].fifo_wr_tile_ptr +=
// (std::uint32_t)(cb_interface[output_id].fifo_page_size)*cb_interface[output_id].ublock_ct;
// }
} else {
if constexpr (!untilize) {
pack_tile_addr = cb_interface[output_id].fifo_wr_ptr + cb_interface[output_id].fifo_wr_tile_ptr - 1;
cb_interface[output_id].fifo_wr_tile_ptr += cb_interface[output_id].fifo_page_size;
}
Expand All @@ -164,51 +169,40 @@ inline void llk_pack(std::uint32_t tile_index, std::uint32_t output, std::uint32

std::uint32_t pack_tile_addr = get_output_tile_address<out_of_order_output, untilize>(output_id, output_tile_index);

_llk_pack_<DstSync::SyncHalf, untilize, is_fp32_dest_acc_en>(tile_index, pack_tile_addr);
_llk_pack_<DstSync::SyncHalf, untilize, is_fp32_dest_acc_en>(
tile_index,
pack_tile_addr
);
}

/*************************************************************************
* LLK PACK UNTILIZE
*************************************************************************/

template <std::uint32_t block_ct_dim = 8, std::uint32_t full_ct_dim = block_ct_dim, bool diagonal = false>
inline void llk_pack_untilize_init(
std::uint32_t output, const std::uint32_t face_r_dim = FACE_R_DIM, const std::uint32_t num_faces = 4) {
const std::uint32_t output_id = get_output_id(output);
template <std::uint32_t block_ct_dim = 8>
inline void llk_pack_untilize_init() {
_llk_pack_untilize_init_<block_ct_dim>();
}

_llk_pack_untilize_init_<block_ct_dim, full_ct_dim, diagonal>(pack_dst_format[output_id], face_r_dim, num_faces);

// Pack row by row
if constexpr (diagonal) {
TT_SETADCXX(p_setadc::PAC, 1 - 1, 0x0);
} else {
TT_SETADCXX(p_setadc::PAC, FACE_R_DIM - 1, 0x0);
}
}
template <std::uint32_t block_ct_dim = 8>
inline void llk_pack_untilize(std::uint32_t num_blocks, std::uint32_t output, const std::uint32_t face_r_dim = FACE_R_DIM, const std::uint32_t num_faces = 4, const std::uint32_t block_c_index = 0) {

template <std::uint32_t block_ct_dim = 8, std::uint32_t full_ct_dim = block_ct_dim, bool diagonal = false>
inline void llk_pack_untilize(
std::uint32_t block_rt_dim,
std::uint32_t output,
const std::uint32_t face_r_dim = FACE_R_DIM,
const std::uint32_t num_faces = 4,
const std::uint32_t block_c_index = 0) {
const std::uint32_t output_id = get_output_id(output);
std::uint32_t pack_tile_addr =
cb_interface[output_id].fifo_wr_ptr - 1 +
SCALE_DATUM_SIZE(
pack_dst_format[output_id],
(block_c_index * ((num_faces > 2) ? num_faces / 2 : num_faces) * block_ct_dim * FACE_C_DIM)) /
16;
std::uint32_t pack_tile_addr = cb_interface[output_id].fifo_wr_ptr - 1 + SCALE_DATUM_SIZE(pack_dst_format[output_id], (block_c_index * ((num_faces>1) ? num_faces/2 : 1) * block_ct_dim * FACE_R_DIM))/16;

for (std::uint32_t block=0; block<num_blocks; block++) {

for (std::uint32_t block_rt = 0; block_rt < block_rt_dim; block_rt++) {
_llk_pack_untilize_<block_ct_dim, full_ct_dim, diagonal>(
pack_tile_addr, pack_dst_format[output_id], face_r_dim, num_faces, block_rt * block_ct_dim);
_llk_pack_untilize_<block_ct_dim>(
pack_tile_addr,
pack_dst_format[output_id]
);

pack_tile_addr += full_ct_dim * cb_interface[output_id].fifo_page_size;
pack_tile_addr += block_ct_dim*cb_interface[output_id].fifo_page_size;
}
}


template <bool out_of_order_output = false, bool untilize = false, bool is_fp32_dest_acc_en = false>
inline void llk_matmul_pack(
std::uint32_t start_tile_index, std::uint32_t output, uint32_t ntiles, std::uint32_t output_tile_index = 0) {
Expand Down Expand Up @@ -240,14 +234,15 @@ inline void llk_pack_dest_section_done() {
_llk_pack_dest_section_done_<DstSync::SyncHalf, is_fp32_dest_acc_en>();
}

template <bool untilize = false, bool diagonal = false>
template <bool untilize = false>
inline void llk_init_packer_dest_offset_registers(const std::uint32_t pack_output = 16) {
const std::uint32_t output_id = get_output_id(pack_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_init_packer_dest_offset_registers_<DstSync::SyncHalf, DstTileFaceLayout::RowMajor, untilize, diagonal>(
face_r_dim, narrow_tile);
_llk_init_packer_dest_offset_registers_<DstSync::SyncHalf, DstTileFaceLayout::RowMajor, untilize>(
face_r_dim,
narrow_tile);
}

template <bool untilize = false, bool is_fp32_dest_acc_en = false>
Expand All @@ -257,7 +252,8 @@ inline void llk_pack_dest_init(const std::uint32_t pack_output = 16) {
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_dest_init_<DstSync::SyncHalf, DstTileFaceLayout::RowMajor, untilize, is_fp32_dest_acc_en>(
face_r_dim, narrow_tile);
face_r_dim,
narrow_tile);
}

template <bool mail2math = true, bool mail2pack = true>
Expand All @@ -278,18 +274,21 @@ template <bool is_fp32_dest_acc_en = false, bool is_tile_dim_reconfig_en = false
inline void llk_pack_reconfig_data_format(const std::uint32_t new_output) {
const std::uint32_t output_id = get_output_id(new_output);
const std::uint32_t face_r_dim = get_output_face_r_dim(output_id);
const std::uint32_t tile_c_dim = get_output_tile_c_dim(output_id);
const std::uint32_t num_faces = get_output_num_faces(output_id);
const bool partial_face = get_output_partial_face(output_id);
const bool narrow_tile = get_output_narrow_tile(output_id);

_llk_pack_reconfig_data_format_<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, DstTileFaceLayout::RowMajor>(
_llk_pack_reconfig_data_format_<is_fp32_dest_acc_en, is_tile_dim_reconfig_en, DstTileFaceLayout::RowMajor, false>(
pack_src_format[output_id],
pack_dst_format[output_id],
cb_interface[output_id].fifo_page_size,
face_r_dim,
tile_c_dim,
num_faces,
partial_face,
narrow_tile);
narrow_tile
);
}

template <bool is_fp32_dest_acc_en = false, bool is_tile_dim_reconfig_en = false>
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/third_party/tt_llk_blackhole

0 comments on commit 29fc7f3

Please sign in to comment.