#13106: add optional memory config for output to maxpool
mywoodstock committed Oct 16, 2024
1 parent d1450c7 commit f3c290a
Showing 9 changed files with 148 additions and 30 deletions.
31 changes: 26 additions & 5 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -23,6 +23,7 @@ def run_max_pool(
     dilation,
     device,
     dtype,
+    memory_config=None,
 ):
     in_n, in_c, in_h, in_w = act_shape
     kernel_h, kernel_w = kernel_size
@@ -107,10 +108,9 @@ def run_max_pool(
         stride=[stride_h, stride_w],
         padding=[pad_h, pad_w],
         dilation=[dilation_h, dilation_w],
+        memory_config=memory_config,
     )

-    # interleaved_mem_config = ttnn.L1_MEMORY_CONFIG
-    # output = ttnn.to_memory_config(output, interleaved_mem_config)
     output_host = output.cpu()
     output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
     output_pytorch = output_pytorch_padded[:, :, :, :in_c]
@@ -129,9 +129,6 @@ def run_max_pool(
     golden_shape = golden_pytorch.shape
     output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])

-    # torch.save(output_pytorch, "output_pytorch.pt")
-    # torch.save(golden_pytorch, "golden_pytorch.pt")
-
     output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2)) ## N, C, H, W
     passing, pcc = assert_with_pcc(output_pytorch, golden_pytorch)

@@ -151,6 +148,10 @@ def run_max_pool(
     if dtype == ttnn.bfloat16:
         assert isequal

+    if memory_config:
+        logger.debug(f"Output memory config: {memory_config}")
+        assert ttnn.get_memory_config(output) == memory_config
+

 @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
 @pytest.mark.parametrize(
@@ -228,6 +229,26 @@ def test_run_max_pool(
     run_max_pool(act_shape, kernel_size, padding, stride, dilation, device, dtype)


+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+@pytest.mark.parametrize(
+    "act_shape", ## NCHW
+    (
+        (
+            [8, 64, 112, 112],
+            [1, 512, 10, 10],
+        )
+    ),
+)
+@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
+def test_run_max_pool_mem_config(
+    act_shape,
+    device,
+    memory_config,
+    use_program_cache,
+):
+    run_max_pool(act_shape, (3, 3), (1, 1), (2, 2), (1, 1), device, ttnn.bfloat16, memory_config=memory_config)
+
+
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
 @pytest.mark.parametrize(
     "act_shape", ## NCHW
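Usage sketch (illustrative, not part of the diff): how the new keyword reaches ttnn.max_pool2d from Python. The helper name and tensor shapes are hypothetical, the keyword names mirror the run_max_pool call above, and the [1, 1, N*H*W, C] reshape follows the existing test setup.

import torch
import ttnn

def maxpool_with_dram_output(device):
    # NCHW activation, rearranged to the flattened-NHWC layout the op expects
    act = torch.randn(1, 512, 10, 10, dtype=torch.bfloat16)
    act_nhwc = act.permute(0, 2, 3, 1).reshape(1, 1, 1 * 10 * 10, 512)
    tt_input = ttnn.from_torch(act_nhwc, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    output = ttnn.max_pool2d(
        input_tensor=tt_input,
        batch_size=1,
        input_h=10,
        input_w=10,
        channels=512,
        kernel_size=[3, 3],
        stride=[2, 2],
        padding=[1, 1],
        dilation=[1, 1],
        memory_config=ttnn.DRAM_MEMORY_CONFIG,  # new: request an output memory config
    )
    # mirrors the assertion added to run_max_pool above
    assert ttnn.get_memory_config(output) == ttnn.DRAM_MEMORY_CONFIG
    return output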
68 changes: 68 additions & 0 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -72,6 +72,7 @@ def run_conv(
     has_bias=True,
     shard_layout=None,
     auto_shard=False,
+    memory_config=None,
 ):
     torch.manual_seed(0)
     conv_input_shape = [batch_size, input_channels, input_height, input_width]
@@ -165,6 +166,7 @@ def run_conv(
         conv_op_cache=reader_patterns_cache,
         debug=debug,
         groups=groups,
+        memory_config=memory_config,
     )

     tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
@@ -188,6 +190,11 @@ def run_conv(
     logger.info(f"PCC = {pcc_msg}. Threshold = {pcc}")
     assert passing

+    if memory_config:
+        output_memory_config = ttnn.get_memory_config(tt_output_tensor_on_device)
+        logger.info(f"Output Memory Config : {output_memory_config}")
+        assert output_memory_config == memory_config
+

 def run_conv_with_split(
     device,
@@ -797,6 +804,67 @@ def test_resnet50_conv_wh(
     )


+@skip_for_grayskull()
+@skip_for_blackhole()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.parametrize(
+    "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override",
+    (
+        (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 256}),
+        (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None),
+    ),
+)
+@pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG])
+def test_conv_mem_config_wh(
+    device,
+    use_program_cache,
+    batch_size,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    filter_height,
+    filter_width,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    use_1d_systolic_array,
+    config_override,
+    memory_config,
+):
+    if device.core_grid.y == 7:
+        pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")
+
+    use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0
+    run_conv(
+        device,
+        ttnn.MathFidelity.LoFi,
+        ttnn.bfloat8_b,
+        ttnn.bfloat8_b,
+        batch_size,
+        output_channels,
+        input_channels,
+        input_height,
+        input_width,
+        filter_height,
+        filter_width,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        use_1d_systolic_array,
+        config_override=config_override,
+        use_shallow_conv_variant=use_shallow_conv_variant,
+        transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH
+        packer_l1_acc=True,
+        fp32_accum=False,
+        has_bias=True,
+        auto_shard=False,
+        memory_config=memory_config,
+    )
+
+
 @skip_for_grayskull()
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
 @pytest.mark.parametrize(
21 changes: 16 additions & 5 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -725,7 +725,8 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
     std::array<uint32_t, 2> dilation,
     uint32_t groups,
     std::optional<const ttnn::Tensor> bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_) {
+    std::optional<const Conv2dConfig> conv_config_,
+    const std::optional<const MemoryConfig> memory_config) {

     Conv2dConfig conv_config = conv_config_.value_or(Conv2dConfig());
     uint32_t output_height = ((input_height - kernel_size[0] - ((kernel_size[0] - 1 ) * (dilation[0] - 1)) + 2 * padding[0]) / stride[0]) + 1;
@@ -885,6 +886,11 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
             conv_config.enable_split_reader,
             conv_config.enable_subblock_padding);
         ttnn::operations::core::deallocate(halo_output);
+
+        if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) {
+            conv_output = ttnn::to_memory_config(conv_output, memory_config.value(), std::nullopt);
+        }
+
         return {conv_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
     } else {
         // run conv as matmul
@@ -902,7 +908,6 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
             matmul_input = ttnn::operations::downsample::downsample(
                 input_tensor_post_tm, {batch_size, input_height, input_width, stride[0], stride[1]});
             if (conv_config.deallocate_activation) {
-                // input_tensor_post_tm.deallocate();
                 ttnn::operations::core::deallocate(input_tensor_post_tm);
             }
         }
@@ -917,9 +922,13 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
             conv_config.dtype,
             compute_kernel_config});
         if (conv_config.deallocate_activation) {
-            // matmul_input.deallocate();
             ttnn::operations::core::deallocate(matmul_input);
         }
+
+        if (memory_config.has_value() && memory_config.value() != matmul_output.memory_config()) {
+            matmul_output = ttnn::to_memory_config(matmul_output, memory_config.value(), std::nullopt);
+        }
+
         return {matmul_output, output_height, output_width, weight_tensor_on_device, bias_tensor_on_device};
     }
 }
@@ -1031,7 +1040,8 @@ template std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optiona
     std::array<uint32_t, 2> dilation,
     uint32_t groups,
     std::optional<const ttnn::Tensor> bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_);
+    std::optional<const Conv2dConfig> conv_config_,
+    const std::optional<const MemoryConfig> memory_config);

 template std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>> conv2d<MeshDevice>(
     const ttnn::Tensor& input_tensor,
@@ -1048,7 +1058,8 @@ template std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optiona
     std::array<uint32_t, 2> dilation,
     uint32_t groups,
     std::optional<const ttnn::Tensor> bias_tensor,
-    std::optional<const Conv2dConfig> conv_config_);
+    std::optional<const Conv2dConfig> conv_config_,
+    const std::optional<const MemoryConfig> memory_config);

 } // namespace conv2d
 } // namespace operations
13 changes: 8 additions & 5 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp
@@ -191,7 +191,8 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
     std::array<uint32_t, 2> dilation,
     uint32_t groups,
     std::optional<const ttnn::Tensor> bias_tensor = std::nullopt,
-    std::optional<const Conv2dConfig> conv_config_ = std::nullopt);
+    std::optional<const Conv2dConfig> conv_config_ = std::nullopt,
+    const std::optional<const MemoryConfig> memory_config = std::nullopt);


 struct Conv2dOperation{
@@ -211,8 +212,9 @@ struct Conv2dOperation{
         std::array<uint32_t, 2> dilation,
         uint32_t groups,
         std::optional<const ttnn::Tensor> bias_tensor = std::nullopt,
-        std::optional<const Conv2dConfig> conv_config_ = std::nullopt){
-        return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_);
+        std::optional<const Conv2dConfig> conv_config_ = std::nullopt,
+        const std::optional<const MemoryConfig> memory_config = std::nullopt){
+        return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_, memory_config);
     }

     static std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>> invoke(
@@ -231,8 +233,9 @@ struct Conv2dOperation{
         std::array<uint32_t, 2> dilation,
         uint32_t groups,
         std::optional<const ttnn::Tensor> bias_tensor = std::nullopt,
-        std::optional<const Conv2dConfig> conv_config_ = std::nullopt){
-        return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_);
+        std::optional<const Conv2dConfig> conv_config_ = std::nullopt,
+        const std::optional<const MemoryConfig> memory_config = std::nullopt){
+        return conv2d(input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config_, memory_config);
     }
 };

8 changes: 6 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -56,8 +56,9 @@ void py_bind_conv2d(py::module& module) {
                uint32_t groups,
                std::optional<const ttnn::Tensor> bias_tensor,
                std::optional<const Conv2dConfig> conv_config,
+               const std::optional<const MemoryConfig> memory_config,
                const uint8_t& queue_id) -> std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>> {
-                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config);
+                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, memory_config);
            },
            py::kw_only(),
            py::arg("input_tensor"),
@@ -75,6 +76,7 @@
            py::arg("groups"),
            py::arg("bias_tensor") = std::nullopt,
            py::arg("conv_config") = std::nullopt,
+           py::arg("memory_config") = std::nullopt,
            py::arg("queue_id") = 0},

        ttnn::pybind_overload_t{
@@ -93,8 +95,9 @@
                uint32_t groups,
                std::optional<const ttnn::Tensor> bias_tensor,
                std::optional<const Conv2dConfig> conv_config,
+               const std::optional<const MemoryConfig> memory_config,
                const uint8_t& queue_id) -> std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::Tensor>> {
-                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config);
+                return self(queue_id, input_tensor, weight_tensor, device, in_channels, out_channels, batch_size, input_height, input_width, kernel_size, stride, padding, dilation, groups, bias_tensor, conv_config, memory_config);
            },
            py::kw_only(),
            py::arg("input_tensor"),
@@ -112,6 +115,7 @@
            py::arg("groups"),
            py::arg("bias_tensor") = std::nullopt,
            py::arg("conv_config") = std::nullopt,
+           py::arg("memory_config") = std::nullopt,
            py::arg("queue_id") = 0}
    );

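Usage sketch (illustrative, not part of the diff): calling the updated binding from Python. Tensor handles and shapes are hypothetical; the keyword names and the five-element return tuple follow the py::arg list and the C++ signature above.

import ttnn

def conv_with_l1_output(device, tt_input, tt_weight, tt_bias):
    out, out_h, out_w, weights_on_dev, bias_on_dev = ttnn.conv2d(
        input_tensor=tt_input,
        weight_tensor=tt_weight,
        device=device,
        in_channels=64,
        out_channels=64,
        batch_size=8,
        input_height=56,
        input_width=56,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        dilation=(1, 1),
        groups=1,
        bias_tensor=tt_bias,
        conv_config=None,  # falls back to Conv2dConfig()
        memory_config=ttnn.L1_MEMORY_CONFIG,  # new: request an output memory config
    )
    assert ttnn.get_memory_config(out) == ttnn.L1_MEMORY_CONFIG
    return out, out_h, out_w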
28 changes: 17 additions & 11 deletions ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp
@@ -12,7 +12,7 @@
 namespace ttnn {
 namespace operations::pool {

-Tensor MaxPool2DOp::invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_t batch_size, uint32_t input_h, uint32_t input_w, uint32_t channels, std::array<uint32_t, 2> kernel_size, std::array<uint32_t, 2> stride, std::array<uint32_t, 2> padding, std::array<uint32_t, 2> dilation) {
+Tensor MaxPool2DOp::invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_t batch_size, uint32_t input_h, uint32_t input_w, uint32_t channels, std::array<uint32_t, 2> kernel_size, std::array<uint32_t, 2> stride, std::array<uint32_t, 2> padding, std::array<uint32_t, 2> dilation, const std::optional<const MemoryConfig> memory_config) {

     sliding_window::SlidingWindowConfig sliding_window_config{
         .batch_size = batch_size,
@@ -30,10 +30,10 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_
     bool is_in_tiled = input_tensor.dtype() == DataType::BFLOAT8_B; // input tiled for bfp8_b

     sliding_window::ParallelConfig parallel_config;
-    MemoryConfig memory_config = input_tensor_sharded.memory_config();
+    MemoryConfig out_memory_config = input_tensor_sharded.memory_config();
     uint32_t num_cores_nhw = 0;

-    if (!memory_config.shard_spec.has_value()) {
+    if (!out_memory_config.shard_spec.has_value()) {
         // Input is not sharded. Perform sharding.
         parallel_config = conv::conv2d::determine_parallel_config(
             TensorMemoryLayout::HEIGHT_SHARDED,
@@ -48,12 +48,12 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_
         num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config);
         auto sharded_mem_config = conv::conv2d::create_sharded_memory_config_from_parallel_config(input_tensor_sharded.shape(), parallel_config, is_in_tiled ? tt::constants::TILE_HEIGHT : 1);
         input_tensor_sharded = ttnn::to_memory_config(input_tensor_sharded, sharded_mem_config, std::nullopt);
-        memory_config = input_tensor_sharded.memory_config();
+        out_memory_config = input_tensor_sharded.memory_config();
     } else {
         // input is already sharded, use it as is
-        const auto shard_grid = memory_config.shard_spec.value().grid;
-        const auto shard_scheme = memory_config.memory_layout;
-        const auto shard_orientation = memory_config.shard_spec.value().orientation;
+        const auto shard_grid = out_memory_config.shard_spec.value().grid;
+        const auto shard_scheme = out_memory_config.memory_layout;
+        const auto shard_orientation = out_memory_config.shard_spec.value().orientation;
         TT_FATAL(shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED, "Only height sharded tensors are supported.");
         TT_FATAL(shard_orientation == ShardOrientation::ROW_MAJOR, "Only row major orientation is supported.");
         parallel_config.grid = shard_grid;
@@ -62,13 +62,13 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_
         num_cores_nhw = conv::conv2d::get_num_cores_nhw_from_parallel_config(parallel_config);
     }
     // update the shard spec to match the output shape
-    auto shard_spec = memory_config.shard_spec.value();
+    auto shard_spec = out_memory_config.shard_spec.value();
     uint32_t output_shard_width_padded = input_tensor.dtype() == DataType::BFLOAT8_B ? tt::round_up(output_shape[3], tt::constants::TILE_WIDTH) : tt::round_up(output_shape[3] * tt::datum_size(tt::tt_metal::datatype_to_dataformat_converter(input_tensor.dtype())), tt::constants::TILE_WIDTH);
     uint32_t output_nhw = output_shape[0] * output_shape[1] * output_shape[2];
     uint32_t output_nhw_padded = tt::round_up(output_nhw, num_cores_nhw * (is_out_tiled ? tt::constants::TILE_HEIGHT : 1));
     uint32_t output_shard_height_padded = output_nhw_padded / num_cores_nhw;
     log_debug(tt::LogOp, "output_nhw: {}, output_nhw_padded: {}, output_shard_height_padded: {}, output_shard_width_padded: {}", output_nhw, output_nhw_padded, output_shard_height_padded, output_shard_width_padded);
-    memory_config.shard_spec = ShardSpec{shard_spec.grid, {output_shard_height_padded, output_shard_width_padded}, ShardOrientation::ROW_MAJOR, false};
+    out_memory_config.shard_spec = ShardSpec{shard_spec.grid, {output_shard_height_padded, output_shard_width_padded}, ShardOrientation::ROW_MAJOR, false};

     sliding_window_config = sliding_window::SlidingWindowConfig{
         .batch_size = batch_size,
@@ -95,12 +95,18 @@ Tensor MaxPool2DOp::invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_
         input_tensor_sharded.memory_config(),
         is_out_tiled);

-    return ttnn::prim::max_pool2d(
+    auto output_tensor = ttnn::prim::max_pool2d(
         queue_id,
         haloed_tensor,
         sliding_window_config,
         DataType::BFLOAT16, // input_tensor.dtype(), // currently only bfp16 output is supported
-        memory_config);
+        out_memory_config);
+
+    if (memory_config.has_value() && memory_config.value() != out_memory_config) {
+        output_tensor = ttnn::to_memory_config(output_tensor, memory_config.value(), std::nullopt);
+    }
+
+    return output_tensor;
 }

 } // namespace operations::pool
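Design note with a behavior sketch (not part of the diff): the kernel still writes its result in the pool's native height-sharded layout (out_memory_config); the optional memory_config only appends a ttnn.to_memory_config reshard when the requested config differs, so callers that omit it pay no extra cost. A hedged Python check of that contract, with hypothetical shapes and the same assumed keywords as above:

import torch
import ttnn

def check_maxpool_output_configs(device):
    # pre-flattened NHWC input: [1, 1, N*H*W, C] for N=8, H=W=112, C=64
    act = torch.randn(1, 1, 8 * 112 * 112, 64, dtype=torch.bfloat16)
    tt_input = ttnn.from_torch(act, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
    pool_args = dict(
        batch_size=8, input_h=112, input_w=112, channels=64,
        kernel_size=[3, 3], stride=[2, 2], padding=[1, 1], dilation=[1, 1],
    )

    # no memory_config: the output keeps the op's native sharded config
    native = ttnn.max_pool2d(input_tensor=tt_input, **pool_args)
    assert ttnn.get_memory_config(native) != ttnn.DRAM_MEMORY_CONFIG

    # explicit memory_config: the output is resharded to match the request
    resharded = ttnn.max_pool2d(input_tensor=tt_input, **pool_args, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    assert ttnn.get_memory_config(resharded) == ttnn.DRAM_MEMORY_CONFIG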
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.hpp
@@ -16,7 +16,7 @@ namespace ttnn {
 namespace operations::pool {

 struct MaxPool2DOp {
-    static Tensor invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_t batch_size, uint32_t input_h, uint32_t input_w, uint32_t channels, std::array<uint32_t, 2> kernel_size, std::array<uint32_t, 2> stride, std::array<uint32_t, 2> padding, std::array<uint32_t, 2> dilation);
+    static Tensor invoke(uint8_t queue_id, const Tensor& input_tensor, uint32_t batch_size, uint32_t input_h, uint32_t input_w, uint32_t channels, std::array<uint32_t, 2> kernel_size, std::array<uint32_t, 2> stride, std::array<uint32_t, 2> padding, std::array<uint32_t, 2> dilation, const std::optional<const MemoryConfig> memory_config = std::nullopt);

 };