Fix to concat support for tensors with tile padding (#15513)
jaykru-tt authored Nov 29, 2024
1 parent a2b1b0b commit 8c02d35
Showing 2 changed files with 71 additions and 13 deletions.
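What the commit addresses: ttnn stores TILE_LAYOUT tensors in 32x32 tiles, so a tensor with logical shape [1, 1, 12, 50] is physically padded out to [1, 1, 32, 64]. Before this change, concatenating tiled tensors along a dimension that carries such padding was unsupported, and the unit test skipped it with "Cannot concat tensors with tile padding". Below is a minimal usage sketch of the case the commit enables, mirroring the added test; the device open/close calls and device id are assumptions about the host setup, not part of the diff.

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Width 50 is not a multiple of the 32-wide tile, so the TILE_LAYOUT inputs carry
# padding along dim -1, the dimension being concatenated.
shapes, dim = [[1, 1, 12, 50], [1, 1, 12, 50]], -1
torch_inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in shapes]
expected = torch.concat(torch_inputs, dim=dim)

inputs = [ttnn.from_torch(t, layout=ttnn.TILE_LAYOUT, device=device) for t in torch_inputs]
output = ttnn.to_torch(ttnn.concat(inputs, dim=dim))  # no has_tile_padding guard needed anymore

print(output.shape)                                    # torch.Size([1, 1, 12, 100]), the logical concat shape
print(torch.allclose(expected.float(), output.float()))

ttnn.close_device(device)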
25 changes: 22 additions & 3 deletions tests/ttnn/unit_tests/operations/test_concat.py
@@ -11,6 +11,28 @@
from tests.ttnn.utils_for_testing import assert_with_pcc


+@pytest.mark.parametrize(
+    "concat_spec",
+    (([[1, 1, 12, 50], [1, 1, 12, 50]], -1),),
+)
+@pytest.mark.parametrize("async_mode", [True, False], ids=["async_on", "async_off"])
+def test_tiled_concat(device, concat_spec, async_mode):
+    shapes, dim = concat_spec
+    device.enable_async(async_mode)
+    torch_input_tensors = [torch.rand(shape, dtype=torch.bfloat16) for shape in shapes]
+    torch_output_tensor = torch.concat(torch_input_tensors, dim=dim)
+
+    input_tensors = [
+        ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+        for torch_input_tensor in torch_input_tensors
+    ]
+
+    output = ttnn.concat(input_tensors, dim=dim)
+    output = ttnn.to_torch(output)
+
+    assert_with_pcc(torch_output_tensor, output, 0.9999)
+
+
@pytest.mark.parametrize("height", [20, 32])
@pytest.mark.parametrize("width", [4, 32])
@pytest.mark.parametrize("dim", [0, 1])
@@ -24,9 +46,6 @@ def test_concat(device, height, width, dim, async_mode):
    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

-    if ttnn.has_tile_padding(input_tensor_a, dim=dim) or ttnn.has_tile_padding(input_tensor_b, dim=dim):
-        pytest.skip("Cannot concat tensors with tile padding")
-
    output = ttnn.concat([input_tensor_a, input_tensor_b], dim=dim)
    output = ttnn.to_torch(output)

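The guard deleted above is the user-visible effect on the test suite: concatenation along a tile-padded dimension no longer has to be skipped. For readers unfamiliar with the predicate the old skip used, here is a small self-contained check of when it fires; the expected True/False values assume the usual 32x32 tile footprint, and the device setup calls are assumptions rather than part of the diff.

import torch
import ttnn

device = ttnn.open_device(device_id=0)

t = ttnn.from_torch(
    torch.rand([1, 1, 12, 50], dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
# Only the last two dims of a TILE_LAYOUT tensor are padded, up to multiples of 32.
print(ttnn.has_tile_padding(t, dim=3))  # expected True: logical width 50 is padded up to 64
print(ttnn.has_tile_padding(t, dim=0))  # expected False: dim 0 carries no tile padding

ttnn.close_device(device)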
59 changes: 49 additions & 10 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
@@ -16,6 +16,8 @@
#include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/slice/device/slice_op.hpp"

#include <ranges>
#include <utility>
@@ -85,7 +87,9 @@ MassagedConcat build_unsqueeze_concat(int input_rank, const MemoryConfig& output
}});
}

-MassagedConcat build_untilize_rm_retilize_concat(uint8_t queue_id, const MemoryConfig& output_memory_config) {
+MassagedConcat build_untilize_rm_retilize_concat(uint8_t queue_id,
+const MemoryConfig& output_memory_config,
+ttnn::SimpleShape logical_output_shape) {
return MassagedConcat(MassagedConcatParams{
.predicate = [](const std::vector<ttnn::Tensor>& tensors, int dim, unsigned int groups) -> bool {
// untilize_rm_retilize if the concat dim is padded for tilized tensors
@@ -96,41 +100,64 @@ MassagedConcat build_untilize_rm_retilize_concat(uint8_t queue_id, const MemoryC
concat_db_print(res, "untilize_rm_retilize required");
return res;
},
-.pre_transform = [](const std::vector<ttnn::Tensor>& tensors, int dim, unsigned int groups) -> OwnedConcatArgs {
+.pre_transform =
+[queue_id, output_memory_config](const std::vector<ttnn::Tensor>& tensors, int dim, unsigned int groups) -> OwnedConcatArgs {
std::vector<ttnn::Tensor> itensors;
itensors.reserve(tensors.size());
std::transform(
tensors.begin(),
tensors.end(),
std::back_inserter(itensors),
-[](const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
+[=](const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
TT_FATAL(
input_tensor.get_layout() == ttnn::TILE_LAYOUT,
"ttnn.concat: expected all input tensors to be in tile layout");
auto untilized_tensor = ttnn::untilize(input_tensor);
-// untilized, so now we have a padded rm tensor
+// untilized, so now we have a padded rm tensor. we slice to
+// remove the padding.
+std::vector<uint32_t> begins_vec(input_tensor.get_shape().rank(), 0);
+tt::stl::Span<const uint32_t> begins = begins_vec;
+tt::stl::Span<const uint32_t> ends = input_tensor.get_logical_shape().view();
+std::vector<uint32_t> steps_vec(input_tensor.get_shape().rank(), 1);
+tt::stl::Span<const uint32_t> steps = steps_vec;
+
+// we now perform a padding-oblivious slice to remove the
+// tile padding.
+// FIXME: change this to a legit slice call once
+// padding-oblivious entry point is uplifted to the slice
+// op.
+untilized_tensor = operation::run(
+SliceDeviceOperation{begins, ends, steps, output_memory_config},
+{untilized_tensor},
+{},
+{std::nullopt},
+queue_id)[0];

untilized_tensor = ttnn::reshape(
untilized_tensor,
-ttnn::Shape{
-input_tensor.get_logical_shape().view(), untilized_tensor.get_padded_shape().view()});
+ttnn::Shape{input_tensor.get_logical_shape().view(), input_tensor.get_logical_shape().view()});
return untilized_tensor;
});
return std::make_tuple(itensors, dim, groups);
},
-.post_transform = [queue_id](const ttnn::Tensor& output) -> ttnn::Tensor {
+.post_transform = [&logical_output_shape,
+queue_id](const ttnn::Tensor& output) -> ttnn::Tensor {
// now we have a rm tensor, so we need to ensure it's padded to tile size and re-tilize it
if (output.get_layout() != ttnn::TILE_LAYOUT) {
auto padded = pad_to_tile_vol(queue_id, output, 0.0f, true, output.memory_config());
concat_db_print(true, "[DEBUG] padded to tile layout, now tilizing.");
auto tilized =
ttnn::tilize_with_val_padding(padded, padded.get_legacy_shape(), 0.0f, output.memory_config());
concat_db_print(true, "[DEBUG] tilized");
-return tilized;
+// need to reshape tilized result to logical concat output shape
+auto reshaped = ttnn::reshape(
+tilized, ttnn::Shape{logical_output_shape.view(), tilized.get_padded_shape().view()});
+return reshaped;
}
concat_db_print(true, "[DEBUG] already tilized");
return output;
},
-.operation = [output_memory_config](
+.operation = [&output_memory_config](
const std::vector<ttnn::Tensor>& tensors, int dim, unsigned int groups) -> ttnn::Tensor {
std::vector<ttnn::Tensor> itensors(tensors);
auto res = concat_impl(itensors, dim, groups, output_memory_config);
@@ -279,7 +306,19 @@ ttnn::Tensor ConcatOperation::invoke(
shapes_match,
"All dimensions must be the same size except for the dimension along which the contenation is taking place.");

-auto untilize_rm_retilize_concat = build_untilize_rm_retilize_concat(queue_id, mem_config);
+auto compute_output_shape = [](const std::vector<ttnn::Tensor>& tensors, int dim) -> ttnn::SimpleShape {
+ttnn::SimpleShape shape_out = tensors[0].get_logical_shape();
+shape_out[dim] = 0;
+for (const Tensor& in_ref : tensors) {
+ttnn::SimpleShape curr_shape = in_ref.get_logical_shape();
+shape_out[dim] += curr_shape[dim];
+}
+return shape_out;
+};
+
+ttnn::SimpleShape logical_output_shape = compute_output_shape(input_tensors, dim);
+
+auto untilize_rm_retilize_concat = build_untilize_rm_retilize_concat(queue_id, mem_config, logical_output_shape);
auto non_aligned_last_dim_concat = build_non_aligned_last_dim_concat(input_tensors, queue_id, mem_config);
auto massaged_concat = untilize_rm_retilize_concat.sequence(non_aligned_last_dim_concat);

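Conceptually, the repaired build_untilize_rm_retilize_concat path untilizes each input to row-major, slices the tile padding away down to the logical shape, concatenates, then pads and re-tilizes the result and reshapes it to the logical output shape produced by compute_output_shape. The device-free torch sketch below illustrates why stripping the padding before concatenating matters; pad_to_tile and the TILE constant are illustrative stand-ins for what the tiled layout does on device, not ttnn APIs.

import torch
import torch.nn.functional as F

TILE = 32

def pad_to_tile(t):
    # Zero-pad the last two dims up to multiples of TILE, a host-side model of
    # the padding a TILE_LAYOUT tensor carries.
    h, w = t.shape[-2], t.shape[-1]
    return F.pad(t, (0, -w % TILE, 0, -h % TILE))

a, b = torch.rand(1, 1, 12, 50), torch.rand(1, 1, 12, 50)
reference = torch.concat([a, b], dim=-1)  # logical shape [1, 1, 12, 100]

# Concatenating the padded buffers directly leaves a's padding columns sitting
# between a's and b's data, so the logical region of the result is wrong.
naive = torch.concat([pad_to_tile(a), pad_to_tile(b)], dim=-1)
print(torch.equal(naive[..., :12, :100], reference))  # False

# The approach taken by the fix: slice back to the logical shapes first,
# concatenate the row-major data, then re-pad (re-tilize on device).
fixed = pad_to_tile(
    torch.concat([pad_to_tile(a)[..., :12, :50], pad_to_tile(b)[..., :12, :50]], dim=-1)
)
print(torch.equal(fixed[..., :12, :100], reference))  # True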
