Skip to content

Commit

Permalink
#11962: cleanup new reconfig uses
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Oct 9, 2024
1 parent 6767452 commit 413c1d9
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void MAIN {

pack_untilize_uninit(untilized_cache_cb);

unpack_reconfig_data_format_srca(cache_cb, untilized_cache2_cb);
reconfig_data_format_srca(cache_cb, untilized_cache2_cb);
pack_reconfig_data_format(untilized_cache_cb, out_cb);

tilize_init_short(untilized_cache2_cb, Wt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void MAIN {
#ifndef RMSNORM
// calculate var = E(x^2) - E(x)^2
// E(x)^2
unpack_reconfig_data_format(cb_stats_reduced, cb_stats_reduced);
reconfig_data_format(cb_stats_reduced, cb_stats_reduced);
cb_reserve_back(cb_ex_sqr, 1);
cb_wait_front(cb_stats_reduced, 1);
tile_regs_acquire();
Expand All @@ -152,8 +152,8 @@ void MAIN {


// E(x^2) - E(x)^2
unpack_reconfig_data_format_srca(cb_stats_reduced, cb_ex2);
unpack_reconfig_data_format_srcb(cb_stats_reduced, cb_ex_sqr);
reconfig_data_format_srca(cb_stats_reduced, cb_ex2);
reconfig_data_format_srcb(cb_stats_reduced, cb_ex_sqr);
pack_reconfig_data_format(cb_var);
cb_wait_front(cb_ex2, 1);
cb_wait_front(cb_ex_sqr, 1);
Expand All @@ -172,7 +172,7 @@ void MAIN {


// 1/[sqrt(Var + eps)],
unpack_reconfig_data_format(cb_var, cb_eps); // cb_var is cb_stats in case of RMS norm
reconfig_data_format(cb_var, cb_eps); // cb_var is cb_stats in case of RMS norm
pack_reconfig_data_format(cb_stats_reduced);
cb_wait_front(cb_var, 1);
cb_wait_front(cb_eps, 1);
Expand All @@ -198,7 +198,7 @@ void MAIN {

#ifndef RMSNORM
// x - E[x]
unpack_reconfig_data_format(cb_in0, cb_ex_global);
reconfig_data_format(cb_in0, cb_ex_global);
pack_reconfig_data_format(cb_xmm);
index_h_offset = 0;
sub_bcast_cols_init_short();
Expand Down Expand Up @@ -234,7 +234,7 @@ void MAIN {
}

// (x - Ex) * 1/[sqrt(Var + eps)]
unpack_reconfig_data_format(cb_xmm, cb_ex_global);
reconfig_data_format(cb_xmm, cb_ex_global);
mul_bcast_cols_init_short();
index_h_offset = 0;
cb_reserve_back(cb_im, num_tiles_per_block);
Expand Down Expand Up @@ -269,7 +269,7 @@ void MAIN {
cb_wait_front(cb_im, num_tiles_per_block);

if constexpr(do_gamma) {
unpack_reconfig_data_format(cb_im, cb_gamma);
reconfig_data_format(cb_im, cb_gamma);
if constexpr(do_beta == 0) {
pack_reconfig_data_format(cb_out);
}
Expand Down Expand Up @@ -301,7 +301,7 @@ void MAIN {
}

if constexpr(do_beta) {
unpack_reconfig_data_format(cb_fusion, cb_beta);
reconfig_data_format(cb_fusion, cb_beta);
pack_reconfig_data_format(cb_out);
add_bcast_rows_init_short();
cb_wait_front(cb_beta, block_w);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void MAIN {

cb_wait_front(cb_scaler, 1);
#ifndef RMSNORM
unpack_reconfig_data_format_srcb(cb_in0, cb_scaler);
reconfig_data_format_srcb(cb_in0, cb_scaler);
// E[x],
index_h_offset = 0;
reduce_init_delta<false>();
Expand All @@ -96,7 +96,7 @@ void MAIN {
}
reduce_revert_delta();
cb_push_back(cb_ex_partial2, block_h);
unpack_reconfig_data_format_srcb(cb_scaler, cb_in0);
reconfig_data_format_srcb(cb_scaler, cb_in0);
#endif // not RMSNORM

// X^2
Expand Down Expand Up @@ -124,8 +124,8 @@ void MAIN {
cb_push_back(cb_x2, num_tiles_per_block);

// E(x^2)
unpack_reconfig_data_format_srca(cb_in0, cb_x2);
unpack_reconfig_data_format_srcb(cb_in0, cb_scaler);
reconfig_data_format_srca(cb_in0, cb_x2);
reconfig_data_format_srcb(cb_in0, cb_scaler);

cb_wait_front(cb_x2, num_tiles_per_block);

Expand All @@ -152,8 +152,8 @@ void MAIN {
// global reduce, cb_ex <-- cb_ex_external2, cb_ex_partial2
if constexpr(is_allgather_worker) {
cb_wait_front(cb_scaler_global, 1);
unpack_reconfig_data_format_srca(cb_x2, cb_ex_external2);
unpack_reconfig_data_format_srcb(cb_scaler, cb_scaler_global);
reconfig_data_format_srca(cb_x2, cb_ex_external2);
reconfig_data_format_srcb(cb_scaler, cb_scaler_global);
reduce_init_delta<false>();
cb_reserve_back(cb_reduction_out, num_tiles_per_partial_result*num_tiles_per_allgather_worker);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ALWI void REL() { release_dst(); }
void calc_numeric_stable(uint32_t Wt, uint32_t ndst, uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t cb_max, uint32_t cb_out) {
// calculate max val per row
ACQ();
unpack_reconfig_data_format(cb_in, cb_bcast_scaler);
reconfig_data_format(cb_in, cb_bcast_scaler);
cb_reserve_back(cb_max, 1);
cb_wait_front(cb_bcast_scaler, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>();
Expand All @@ -45,7 +45,7 @@ void calc_numeric_stable(uint32_t Wt, uint32_t ndst, uint32_t cb_in, uint32_t cb

// calculate x-max(x)
exp_tile_init<EXP_APPROX>();
unpack_reconfig_data_format_srcb(cb_max);
reconfig_data_format_srcb(cb_max);
cb_wait_front(cb_max, 1);
sub_bcast_cols_init_short();
for (uint32_t wt = 0; wt < Wt; wt += ndst) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template<uint32_t block_w, uint32_t num_subblocks_w, uint32_t subblock_w>
ALWI void calc_numeric_stable(uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t cb_max, uint32_t cb_out) {
// calculate max val per row
ACQ();
unpack_reconfig_data_format(cb_in, cb_bcast_scaler);
reconfig_data_format(cb_in, cb_bcast_scaler);
cb_reserve_back(cb_max, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>();
cb_wait_front(cb_bcast_scaler, 1);
Expand All @@ -37,7 +37,7 @@ ALWI void calc_numeric_stable(uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t

// calculate x-max(x)
exp_tile_init<EXP_APPROX>();
unpack_reconfig_data_format_srcb(cb_max);
reconfig_data_format_srcb(cb_max);
cb_wait_front(cb_max, 1);
sub_bcast_cols_init_short();
uint32_t index_subblock_w_offset = 0;
Expand Down Expand Up @@ -175,8 +175,7 @@ void MAIN {
#ifdef NUMERIC_STABLE
calc_numeric_stable<block_w, num_subblocks_w, subblock_w>(cb_in0, cb_bcast_scaler, cb_max, cb_exps);
#else
unpack_reconfig_data_format(cb_in0, cb_in0);
math_reconfig_data_format(cb_in0, cb_in0);
reconfig_data_format(cb_in0, cb_in0);
pack_reconfig_data_format(cb_exps);
// exp(x)
index_subblock_w_offset = 0;
Expand All @@ -198,8 +197,7 @@ void MAIN {
index_subblock_w_offset += subblock_w;
}
cb_pop_front(cb_in0, block_w);
unpack_reconfig_data_format(cb_exps, cb_bcast_scaler);
math_reconfig_data_format(cb_exps, cb_bcast_scaler);
reconfig_data_format(cb_exps, cb_bcast_scaler);
#endif
#endif // FUSED_SCALE_MASK

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ void MAIN {
reduce_c<PoolType::MAX, ReduceDim::REDUCE_ROW, cb_qk_im, cb_identity_scale_in, cb_cur_max, Sq_chunk_t, Sk_chunk_t>();

if (k_chunk > k_chunk_start) {
unpack_reconfig_data_format(cb_cur_max, cb_prev_max);
reconfig_data_format(cb_cur_max, cb_prev_max);
max_block_inplace(cb_cur_max, cb_prev_max, Sq_chunk_t);
}
/* QK -= cb_cur_max */
Expand Down Expand Up @@ -598,7 +598,7 @@ void MAIN {
/* cb_cur_sum = 1.0 / cb_cur_sum */
cb_push_back(cb_cur_sum, Sq_chunk_t);

unpack_reconfig_data_format(cb_cur_sum, cb_cur_sum); // DEBUG
reconfig_data_format(cb_cur_sum, cb_cur_sum); // DEBUG
pack_reconfig_data_format(cb_cur_sum);
recip_block_inplace(cb_cur_sum, Sq_chunk_t);

Expand Down

0 comments on commit 413c1d9

Please sign in to comment.