#11178: add sharding support to line reduce scatter (#13963)
#11178: add sharding support to reduce scatter
SeanNijjar authored Oct 18, 2024
1 parent bf48f8e commit a638d18
Showing 8 changed files with 276 additions and 93 deletions.

@@ -7,7 +7,10 @@
from loguru import logger
import ttnn
from models.utility_functions import skip_for_grayskull
from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import run_reduce_scatter_test
from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import (
run_reduce_scatter_test,
run_reduce_scatter_sharded_test,
)


@skip_for_grayskull("Requires eth connected devices to run")
@@ -45,7 +48,7 @@
)
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [True])
def test_ring_reduce_scatter_post_commit(
def test_ring_reduce_scatter_n300_post_commit(
n300_mesh_device,
num_devices,
per_chip_output_shape,
@@ -75,3 +78,83 @@ def test_ring_reduce_scatter_post_commit(
num_iters=num_iters,
enable_async=enable_async,
)


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
"num_devices, num_links",
[
(2, 1),
],
)
@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize("tensor_layout", [ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("orientation", [ttnn.ShardOrientation.ROW_MAJOR])
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
],
)
@pytest.mark.parametrize(
"per_chip_output_shape,output_shard_shape,shard_grid,tensor_mem_layout",
(
(
(1, 1, 32, 1792),
(32, 32),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 6))}),
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
),
(
(1, 1, 1792, 32),
(32, 32),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 6))}),
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
),
(
(1, 1, 224, 256),
(32, 32),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 6))}),
ttnn.TensorMemoryLayout.BLOCK_SHARDED,
),
),
)
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("enable_async", [True])
def test_width_sharded_reduce_scatter_N300_post_commit(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
output_shard_shape,
dim,
num_links,
math_op,
shard_grid,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
use_program_cache,
function_level_defaults,
enable_async,
num_iters=5,
):
run_reduce_scatter_sharded_test(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
output_shard_shape,
dim,
num_links,
math_op,
shard_grid,
orientation,
input_dtype,
tensor_layout,
tensor_mem_layout,
use_program_cache=use_program_cache,
function_level_defaults=function_level_defaults,
enable_async=enable_async,
num_iters=num_iters,
)
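
For context on the new parametrization above: each row pairs a per-chip output shape with a shard shape, a core grid, and a sharded memory layout. Below is a minimal sketch of the sharded L1 memory config those parameters presumably map to inside run_reduce_scatter_sharded_test; the variable names are illustrative, and the exact ShardSpec/MemoryConfig constructor signatures are assumed for this ttnn version rather than taken from the diff.

import ttnn

# (1, 1, 32, 1792) width-sharded into (32, 32) shards over the 8x7 core grid
# used in the first parametrization row (assumed mapping, for illustration only).
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 6))})
shard_spec = ttnn.ShardSpec(shard_grid, (32, 32), ttnn.ShardOrientation.ROW_MAJOR, False)
output_mem_config = ttnn.MemoryConfig(
    ttnn.TensorMemoryLayout.WIDTH_SHARDED,  # HEIGHT_/BLOCK_SHARDED in the other rows
    ttnn.BufferType.L1,
    shard_spec,
)

A config like this would then be passed as the output memory config of the reduce-scatter call under test; the helper's actual construction lives in test_reduce_scatter_post_commit.py, which is only partially shown here.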

@@ -348,6 +348,8 @@ def run_reduce_scatter_sharded_test(
f"Not enough devices on machine to implement test case. Wanted {num_devices} but found {len(t3k_mesh_device.get_device_ids())}"
)

logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, scatter_dim: {scatter_dim}")

debug = False

t3k_mesh_device.enable_async(enable_async)
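
The hunk above only adds a log line, but it is a convenient place to recall what run_reduce_scatter_sharded_test exercises: each chip reduces element-wise across devices and keeps a 1/num_devices slice of the result along scatter_dim, so per_chip_output_shape is the post-scatter shape. A small illustrative sketch of that relationship (helper name and derivation assumed, not taken from the test code):

# Assumed relationship, for illustration: the per-chip input shape is the per-chip
# output shape with the scatter dimension scaled back up by the number of devices.
def infer_per_chip_input_shape(per_chip_output_shape, scatter_dim, num_devices):
    shape = list(per_chip_output_shape)
    shape[scatter_dim] *= num_devices
    return shape

# e.g. a (1, 1, 32, 1792) per-chip output scattered on dim 3 across 2 devices
# implies a (1, 1, 32, 3584) per-chip input.
assert infer_per_chip_input_shape((1, 1, 32, 1792), 3, 2) == [1, 1, 32, 3584]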

@@ -455,10 +455,9 @@ FORCE_INLINE void read_wrapped_chunk_from_output_tensor_to_address(
#ifdef INTERLEAVED_MEM_LAYOUT
uint64_t src_noc_addr = get_noc_addr(curr_page_idx, s);
noc_async_read(src_noc_addr, local_l1_read_addr, page_size);
#elif defined SHARDED_MEM_LAYOUT
#elif defined SHARDED_MEM_LAYOUT
ASSERT(false); // unimplemented
#endif
ASSERT(false); // unimplemented
#endif
#elif defined TILED_LAYOUT
#ifdef INTERLEAVED_MEM_LAYOUT
noc_async_read_tile(curr_page_idx, s, local_l1_read_addr);
135 changes: 78 additions & 57 deletions ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send.cpp
@@ -116,6 +116,38 @@ constexpr Shape4D<T> build_wrapped_row_tensor_slice(T n_pages) {
return Shape4D<T>{1, 1, 1, n_pages};
}

///////////////////////////////////////////////////
// COMPILE TIME ARGS
///////////////////////////////////////////////////

constexpr TensorMemoryLayout tensor_layout = static_cast<TensorMemoryLayout>(get_compile_time_arg_val(0));
constexpr BufferType buffer_type = static_cast<BufferType>(get_compile_time_arg_val(1));
constexpr Layout page_layout = static_cast<Layout>(get_compile_time_arg_val(2));
constexpr ttnn::ccl::EriscDataMoverTerminationMode termination_mode = static_cast<ttnn::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(3));
constexpr uint32_t cb_id = get_compile_time_arg_val(4);


#ifdef SHARDED_MEM_LAYOUT
static constexpr bool is_sharded_mode = true;
static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(5);
static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(6);
static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(7);
static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(8);
static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(9);
static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(10);
static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(11) != 0;
#else
static constexpr bool is_sharded_mode = false;
static constexpr uint32_t input_tensor_shard_grid_height = 0;
static constexpr uint32_t input_tensor_shard_grid_width = 0;
static constexpr uint32_t input_tensor_shard_grid_start_y_logical = 0;
static constexpr uint32_t input_tensor_shard_grid_start_x_logical = 0;
static constexpr uint32_t input_tensor_shard_pages_per_shard_y = 0;
static constexpr uint32_t input_tensor_shard_pages_per_shard_x = 0;
static constexpr bool input_tensor_shard_grid_transposed = false;
#endif


template <tt::tt_metal::TensorMemoryLayout tensor_layout, tt::tt_metal::BufferType buffer_type, tt::tt_metal::Layout page_layout>
auto build_source_address_generator(std::size_t &arg_idx, address_t tensor_address, std::size_t page_size, uint32_t cb_id_in0) -> typename source_tensor_addrgen<tensor_layout, buffer_type, page_layout>::type {
constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout);
@@ -126,56 +158,45 @@ auto build_source_address_generator(std::size_t &arg_idx, address_t tensor_addre

using addrgen_type = typename source_tensor_addrgen<tensor_layout, buffer_type, page_layout>::type;

if constexpr (is_row_major_layout) {
if constexpr (is_interleaved) {
if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) {
if constexpr (is_row_major_layout) {
return addrgen_type{
.bank_base_address = tensor_address, .page_size = page_size};
} else if constexpr (is_sharded) {
return tt::tt_metal::address_generators::build_sharded_addr_gen<tensor_layout>(
tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup(
0,//output_shard_grid_nrows,
0,//output_shard_grid_row_map,
0,//output_shard_grid_ncols,
0),//output_shard_grid_col_map),
tt::tt_metal::address_generators::DeviceShardSpecTypeGetter<tensor_layout>::type(
0,//output_tensor_shard_pages_per_shard_y,
0,//output_tensor_shard_pages_per_shard_x,
0,//output_tensor_shard_grid_height,
0,//output_tensor_shard_grid_width,
0,//output_tensor_shard_grid_start_y_logical,
0,//output_tensor_shard_grid_start_x_logical,
0//output_tensor_shard_grid_transposed
),
page_size,
tensor_address
);
ASSERT(false); // unimplemented and untested
}
} else if constexpr (is_tile_page_layout) {
if constexpr (is_interleaved) {
} else {
return addrgen_type{
.bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in0)};
} else if constexpr (is_sharded) {
ASSERT(false);//"Sharded support has not been added to ccl_send yet");
return tt::tt_metal::address_generators::build_sharded_addr_gen<tensor_layout>(
tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup(
0,//output_shard_grid_nrows,
0,//output_shard_grid_row_map,
0,//output_shard_grid_ncols,
0),//output_shard_grid_col_map),
tt::tt_metal::address_generators::DeviceShardSpecTypeGetter<tensor_layout>::type(
0,//output_tensor_shard_pages_per_shard_y,
0,//output_tensor_shard_pages_per_shard_x,
0,//output_tensor_shard_grid_height,
0,//output_tensor_shard_grid_width,
0,//output_tensor_shard_grid_start_y_logical,
0,//output_tensor_shard_grid_start_x_logical,
0//output_tensor_shard_grid_transposed
),
page_size,
tensor_address
);
}
} else if constexpr (
tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED ||
tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED ||
tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) {
size_t input_shard_grid_nrows = get_arg_val<uint32_t>(arg_idx++);
const auto * const input_shard_grid_row_map = reinterpret_cast<const uint32_t * const>(get_arg_addr(arg_idx));
arg_idx += input_shard_grid_nrows;
size_t input_shard_grid_ncols = get_arg_val<uint32_t>(arg_idx++);
const auto * const input_shard_grid_col_map = reinterpret_cast<const uint32_t * const>(get_arg_addr(arg_idx));
arg_idx += input_shard_grid_ncols;

return tt::tt_metal::address_generators::build_sharded_addr_gen<tensor_layout>(
tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup(
input_shard_grid_nrows,
input_shard_grid_row_map,
input_shard_grid_ncols,
input_shard_grid_col_map),
typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter<tensor_layout>::type(
input_tensor_shard_pages_per_shard_y,
input_tensor_shard_pages_per_shard_x,
input_tensor_shard_grid_height,
input_tensor_shard_grid_width,
input_tensor_shard_grid_start_y_logical,
input_tensor_shard_grid_start_x_logical,
input_tensor_shard_grid_transposed
),
page_size,
tensor_address
);
} else {
ASSERT(false);
}
}
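
The sharded branch above consumes its runtime arguments in a fixed order: the number of shard-grid rows, a row map, the number of columns, and a column map, with the remaining shard-spec values baked in as compile-time args 5 through 11. A rough host-side sketch of that packing order follows (purely illustrative; the real argument assembly happens in the reduce-scatter program factory, which is not part of this excerpt, and the map values below are placeholders rather than real NOC coordinates).

# Assumed packing that mirrors the kernel's read order above:
# [nrows, row_map..., ncols, col_map...], where the maps translate logical shard-grid
# rows/columns to physical NOC y/x coordinates (placeholder values shown).
def pack_shard_grid_runtime_args(row_to_noc_y, col_to_noc_x):
    return [len(row_to_noc_y), *row_to_noc_y, len(col_to_noc_x), *col_to_noc_x]

# 7 logical rows and 8 logical columns, matching the (0,0)-(7,6) test core grid.
runtime_args = pack_shard_grid_runtime_args([0] * 7, [0] * 8)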

@@ -190,12 +211,6 @@ void kernel_main() {
// ARGS
///////////////////////////////////////////////////

constexpr TensorMemoryLayout tensor_layout = static_cast<TensorMemoryLayout>(get_compile_time_arg_val(0));
constexpr BufferType buffer_type = static_cast<BufferType>(get_compile_time_arg_val(1));
constexpr Layout page_layout = static_cast<Layout>(get_compile_time_arg_val(2));
constexpr ttnn::ccl::EriscDataMoverTerminationMode termination_mode = static_cast<ttnn::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(3));
constexpr uint32_t cb_id = get_compile_time_arg_val(4);

// Load the input tensor spec
address_t tensor_address = get_arg_val<address_t>(arg_idx++);
address_t num_commands = get_arg_val<address_t>(arg_idx++);
@@ -209,7 +224,8 @@
const uint32_t packet_size_in_pages = get_arg_val<uint32_t>(arg_idx++);
const uint32_t page_size = get_arg_val<uint32_t>(arg_idx++);
auto tensor_addrgen = build_source_address_generator<tensor_layout, buffer_type, page_layout>(arg_idx, tensor_address, page_size, tt::CB::c_in0);
volatile uint32_t* my_edm_worker_semaphore_ptr = reinterpret_cast<volatile uint32_t*>(get_semaphore(get_arg_val<uint32_t>(arg_idx++)));
auto semaphore_id = get_arg_val<uint32_t>(arg_idx++);
volatile uint32_t* my_edm_worker_semaphore_ptr = reinterpret_cast<volatile uint32_t*>(get_semaphore(semaphore_id));

// For now we only support single EDM connection
ccl::edm::WorkerToEdmSender<termination_mode> sender(
@@ -226,6 +242,11 @@
// Instead, open up the CB and use it as a raw scratch space
cb_reserve_back(cb_id, packet_size_in_pages);
const uint32_t local_l1_scratch_buffer_address = get_write_ptr(cb_id);

#ifdef DEBUG_PRINT_ENABLED
DPRINT << "ccl_send has " << (uint32_t)num_commands << " commands" << ENDL();
#endif

for (std::size_t i = 0; i < num_commands; ++i) {
// Generalized would be to get the command header info and then dispatch accordingly - if the command type is singular
//
@@ -235,7 +256,7 @@

{
print_tensor_command(i, command_tensor);
ASSERT(ccl_command.worker_pages_per_slice > 0);
ASSERT(command_tensor.worker_pages_per_slice > 0);

// CURRENTLY ONLY SUPPORTS WRAPPED TENSOR ITERATION COMMANDS
// Implemented really inefficiently for now - in the future we can do more efficient packing and also change
@@ -252,14 +273,14 @@
bool last_page_of_worker = false;
for (uint32_t p = 0; p < command_tensor.worker_pages_per_slice; p += packet_size_in_pages) {
uint32_t n_pages = std::min(packet_size_in_pages, command_tensor.worker_pages_per_slice - p);
ASSERT(ccl_command.worker_start_offset_in_slice.w == 1);
ASSERT(ccl_command.worker_start_offset_in_slice.z == 1);
ASSERT(command_tensor.worker_start_offset_in_slice.w == 0);
ASSERT(command_tensor.worker_start_offset_in_slice.z == 0);
ASSERT(valid_worker_slice_shape.w == 1);
ASSERT(valid_worker_slice_shape.z == 1);
ASSERT(command_tensor.tensor_shape.w == 1);
ASSERT(command_tensor.tensor_shape.z == 1);
ASSERT(ccl_command.tensor_slice_shape.w == 1);
ASSERT(ccl_command.tensor_slice_shape.z == 1);
ASSERT(command_tensor.tensor_slice_shape.w == 1);
ASSERT(command_tensor.tensor_slice_shape.z == 1);

read_wrapped_chunk_from_output_tensor_to_address(
curr_tile_id,
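
Tying the kernel's new compile-time shard arguments back to the Python parametrizations in the test file: with TILE_LAYOUT a (32, 32) shard is exactly one tile, i.e. one page, so pages_per_shard is 1 in both directions, and the CoreRange (0,0)-(7,6) grid is 8 cores wide by 7 high. A hypothetical sketch of that derivation (the real values are produced by the host-side program factory, which is not shown in this excerpt):

TILE_HEIGHT, TILE_WIDTH = 32, 32

# Illustrative mapping from a test parametrization to the kernel's compile-time
# args 5-11; the names mirror the kernel constants, the derivation itself is assumed.
def shard_compile_time_args(shard_shape, grid_start_xy, grid_end_xy, transposed=False):
    shard_h, shard_w = shard_shape
    (start_x, start_y), (end_x, end_y) = grid_start_xy, grid_end_xy
    return {
        "input_tensor_shard_grid_height": end_y - start_y + 1,
        "input_tensor_shard_grid_width": end_x - start_x + 1,
        "input_tensor_shard_grid_start_y_logical": start_y,
        "input_tensor_shard_grid_start_x_logical": start_x,
        "input_tensor_shard_pages_per_shard_y": shard_h // TILE_HEIGHT,
        "input_tensor_shard_pages_per_shard_x": shard_w // TILE_WIDTH,
        "input_tensor_shard_grid_transposed": int(transposed),
    }

# (32, 32) shards on the (0,0)-(7,6) grid: 8x7 grid, one 32x32 page per shard.
print(shard_compile_time_args((32, 32), (0, 0), (7, 6)))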