Skip to content

Commit

Permalink
#0: Changed default layout to RM
Browse files Browse the repository at this point in the history
  • Loading branch information
sankarmanoj-tt committed Jan 23, 2025
1 parent 52542d4 commit bd63047
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
6 changes: 4 additions & 2 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def run_conv(
padded_input_channels=None,
fp32_accum=False,
packer_l1_acc=False,
output_layout=ttnn.TILE_LAYOUT,
output_layout=ttnn.ROW_MAJOR_LAYOUT,
deallocate_activation=False,
debug=False,
groups=1,
Expand All @@ -87,6 +87,8 @@ def run_conv(
else:
total_batch_size = batch_size

if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")
torch.manual_seed(0)
conv_input_shape = [total_batch_size, input_channels, input_height, input_width]
conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
Expand Down Expand Up @@ -1257,7 +1259,7 @@ def test_resnet50_conv_wh_fp32(
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat8_b],
[ttnn.bfloat16],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("enable_auto_formatting", [True, False])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,15 @@ void MAIN {
// DPRINT<<"out_subblock_num_tiles: "<<out_subblock_num_tiles<<ENDL();
// DPRINT<<"tilize_in0: "<<(uint32_t)tilize_in0<<ENDL();
// DPRINT<<"untilize_out: "<<(uint32_t)untilize_out<<ENDL();
DPRINT << "out_cb_id: " << out_cb_id << ENDL();
// DPRINT<<"out_cb_id: "<<out_cb_id<<ENDL();
// DPRINT<<"output_rows_h: "<<output_rows_h<<ENDL();
// DPRINT<<"is_non_tile_height: "<<(uint32_t)is_non_tile_height<<ENDL();
// #ifdef WIDTH_SHARDED
// DPRINT<<"in0_nblocks_w_tilize: "<<in0_nblocks_w_tilize<<ENDL();
// #endif
// DPRINT<<"out_block_num_tiles: "<<out_block_num_tiles<<ENDL();
// DPRINT<<"out_block_w: "<<out_block_w<<ENDL();
DPRINT << "spill: " << (uint32_t)spill << ENDL();
// DPRINT<<"spill: "<<(uint8_t)spill<<ENDL();
DPRINT << "untilize_mode_out_cb_id: " << untilize_mode_out_cb_id << ENDL();)
#ifdef FUSE_BIAS
constexpr uint32_t bias_ntiles_w = get_compile_time_arg_val(16);
Expand Down Expand Up @@ -178,7 +178,11 @@ void MAIN {
const bool use_partials_for_out = (partials_cb_read_ptr == get_local_cb_interface(out_cb_id).fifo_rd_ptr);)
PACK(uint32_t partials_cb_write_ptr = get_local_cb_interface(matmul_partials_cb).fifo_wr_ptr;
const bool use_partials_for_out = (partials_cb_write_ptr == get_local_cb_interface(out_cb_id).fifo_wr_ptr);)
DPRINT_UNPACK(DPRINT << "Saved Read Ptr: " << partials_cb_read_ptr << "\n";)
DPRINT_UNPACK(
DPRINT << "Saved Read Ptr: " << partials_cb_read_ptr << "\n"; if (use_partials_for_out) {
DPRINT << "Using Partials for Out\n";
} else { DPRINT << "Not Using Partials for Out\n"; })

DPRINT_PACK(DPRINT << "Saved CB Write Ptr: " << partials_cb_write_ptr << "\n";)

DPRINT_UNPACK(DPRINT << "MM Out CB Read Ptr: " << mm_out_cb_id << " "
Expand All @@ -201,12 +205,14 @@ void MAIN {
// for each output block we start we relu disabled so that intermediate results are not relu'd
PACK((llk_pack_relu_config(ReluType::NO_RELU)));
#endif
if (untilize_out == false) {
UNPACK(partials_cb_read_ptr = get_local_cb_interface(matmul_partials_cb).fifo_rd_ptr);
PACK(partials_cb_write_ptr = get_local_cb_interface(matmul_partials_cb).fifo_wr_ptr);
DPRINT_UNPACK(DPRINT << "Saved Read Ptr: " << partials_cb_read_ptr << "\n";)
DPRINT_PACK(DPRINT << "Saved CB Write Ptr: " << partials_cb_write_ptr << "\n";)
}
UNPACK(
if (use_partials_for_out) partials_cb_read_ptr =
get_local_cb_interface(matmul_partials_cb).fifo_rd_ptr);
PACK(
if (use_partials_for_out) partials_cb_write_ptr =
get_local_cb_interface(matmul_partials_cb).fifo_wr_ptr);
DPRINT_UNPACK(DPRINT << "Saved Read Ptr: " << partials_cb_read_ptr << "\n";)
DPRINT_PACK(DPRINT << "Saved CB Write Ptr: " << partials_cb_write_ptr << "\n";)
uint32_t curr_matmul_out_cb = matmul_partials_cb;
for (uint32_t in0_block_w_i = 0; in0_block_w_i < in0_num_blocks_w; ++in0_block_w_i) {
#ifdef WIDTH_SHARDED
Expand Down

0 comments on commit bd63047

Please sign in to comment.