From fa69b0b81b9f9f452bfc51abca3180942d24446e Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Wed, 9 Oct 2024 16:56:29 -0700
Subject: [PATCH] #13127: Allow `compute_output_shapes` to use SimpleShape
 instead of LegacyShape, port some ops to SimpleShape (#13645)

* #13127: Prototype of moving some operation to ttnn::SimpleShape

* #13127: Port more ops

* #13127: Infra support for SimpleShape

* #13127: Revert convolution changes

* #13127: Refactor to remove code duplication

* #13127: Fix rebase issues

* #13127: Extract extract_legacy_shape function, make get_physical_shape pure function
---
 docs/source/ttnn/ttnn/dependencies/tt_lib.rst |  8 +++---
 .../unit_tests/gtests/test_ccl_on_galaxy.cpp  |  9 ++++++-
 .../moreh_clip_grad_norm_op.cpp               |  6 ++---
 .../moreh_clip_grad_norm_op.hpp               |  6 ++---
 .../op_library/moreh_dot/moreh_dot_op.cpp     | 10 +++-----
 .../op_library/moreh_dot/moreh_dot_op.hpp     |  2 +-
 .../moreh_dot_backward_op.cpp                 |  2 +-
 .../moreh_dot_backward_op.hpp                 |  2 +-
 .../moreh_layernorm/moreh_layernorm_op.cpp    | 25 +++++--------------
 .../moreh_layernorm/moreh_layernorm_op.hpp    |  2 +-
 .../moreh_layernorm_backward_op.cpp           |  6 ++---
 .../moreh_layernorm_backward_op.hpp           |  4 +--
 .../moreh_matmul/moreh_matmul_op.cpp          | 21 ++++++----------
 .../moreh_matmul/moreh_matmul_op.hpp          |  2 +-
 ttnn/cpp/ttnn/operation.hpp                   |  7 +++---
 .../ccl/all_gather/device/all_gather_op.cpp   |  6 ++---
 .../ccl/all_gather/device/all_gather_op.hpp   |  2 +-
 .../device/reduce_scatter_op.cpp              |  6 ++---
 .../device/reduce_scatter_op.hpp              |  2 +-
 .../bcast/device/bcast_device_operation.cpp   |  8 +++---
 .../bcast/device/bcast_device_operation.hpp   |  2 +-
 .../concat/device/concat_device_operation.cpp |  6 ++---
 .../concat/device/concat_device_operation.hpp |  2 +-
 .../device/all_gather_matmul_op.cpp           |  8 +++---
 .../device/all_gather_matmul_op.hpp           |  2 +-
 .../operations/matmul/device/matmul_op.cpp    | 15 ++++-------
 .../operations/matmul/device/matmul_op.hpp    |  2 +-
 .../device/moreh_dot_device_operation.cpp     | 10 +++-----
 .../device/moreh_dot_device_operation.hpp     |  2 +-
 .../moreh_dot_backward_device_operation.hpp   |  2 +-
 .../moreh_group_norm_device_operation.cpp     | 16 +++++-------
 .../moreh_group_norm_device_operation.hpp     |  2 +-
 .../reduction/prod/device/prod_nc_op.cpp      |  2 +-
 .../reduction/prod/device/prod_nc_op.hpp      |  2 +-
 ttnn/cpp/ttnn/run_operation.cpp               | 23 +++++++++++++++--
 ttnn/cpp/ttnn/tensor/tensor.cpp               |  5 ++++
 ttnn/cpp/ttnn/tensor/tensor.hpp               |  8 ++++++
 ttnn/cpp/ttnn/tensor/types.cpp                | 25 +++++++++++++++++++
 ttnn/cpp/ttnn/tensor/types.hpp                |  2 ++
 39 files changed, 153 insertions(+), 119 deletions(-)

diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index 73ae0517fe9..7a87a746005 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -34,7 +34,7 @@ New Device Operation
     struct <NewOperation> {
         void validate(const std::vector<Tensor> &input_tensors) const;
-        std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+        std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
         std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
         operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
     };
@@ -48,7 +48,7 @@ New Device Operation with a member int some_member
         void validate(const std::vector<Tensor> &input_tensors) const;
-        std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+        std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
         std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
         operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
     };
@@ -61,7 +61,7 @@ New Device Operation with Optional Input Tensors
     struct <NewOperation> {
         void validate(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-        std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+        std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
         std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
         operation::ProgramWithCallbacks create_program(
             const std::vector<Tensor>& input_tensors,
@@ -80,7 +80,7 @@ and create_output_tensors with the additional parameter for the output_tensors.
     struct <NewOperation> {
         void validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
-        std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+        std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
         std::vector<std::optional<Tensor>> create_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
         operation::ProgramWithOptionalOutputTensors create_program(const std::vector<Tensor>& input_tensors, std::vector<std::optional<Tensor>> &output_tensors) const;
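For orientation, an operation ported to the new contract returns logical shapes directly. A minimal hypothetical sketch (`ExampleOp` is illustrative and not part of this patch):

    struct ExampleOp {
        void validate(const std::vector<Tensor> &input_tensors) const;
        // Logical dimensions only; tile padding is no longer encoded in the shape.
        std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
            return {input_tensors.at(0).get_logical_shape()};
        }
        std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
        operation::ProgramWithCallbacks create_program(
            const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
    };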
diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
index 027537ae3a5..df3476bd545 100644
--- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
+++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
@@ -27,7 +27,14 @@ std::vector<Tensor> run_operation(
     const operation::OptionalTensors& optional_output_tensors = {}) {
     static_assert(operation::detail::is_device_operation<DeviceOp>(), "ttnn::run_operation can only dispatch Device Operations!");
     // Create output tensor vector by examining the number of output shapes created by the device operation
-    std::vector<Tensor> outputs(operation::DeviceOperation<Tensors>(devop).compute_output_shapes(input_tensors).size());
+    auto output_shapes = operation::DeviceOperation<Tensors>(devop).compute_output_shapes(input_tensors);
+    size_t output_shapes_size = 0;
+    if (std::holds_alternative<std::vector<tt::tt_metal::LegacyShape>>(output_shapes)) {
+        output_shapes_size = std::get<std::vector<tt::tt_metal::LegacyShape>>(output_shapes).size();
+    } else {
+        output_shapes_size = std::get<std::vector<ttnn::SimpleShape>>(output_shapes).size();
+    }
+    std::vector<Tensor> outputs(output_shapes_size);
     // Populate the workers of the output tensors, based on the input tensors. This is needed for the async engine.
     for (int i = 0; i < outputs.size(); i++) {
         outputs[i] = Tensor(operation::get_workers_for_op_output(std::move(input_tensors), std::move(optional_input_tensors)));
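Because compute_output_shapes now returns a std::variant, the size query above can also be written with std::visit, avoiding naming each alternative; a minimal sketch under the same two-alternative assumption:

    auto output_shapes = operation::DeviceOperation<Tensors>(devop).compute_output_shapes(input_tensors);
    // Both alternatives are std::vector specializations, so .size() is well-formed in the generic lambda.
    size_t output_shapes_size = std::visit([](const auto& shapes) { return shapes.size(); }, output_shapes);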
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp
index b054fbedff6..20276ea7676 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp
@@ -48,7 +48,7 @@ void MorehClipGradNormStep1::validate(
     check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum");
 };
 
-std::vector<tt::tt_metal::LegacyShape> MorehClipGradNormStep1::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
+std::vector<ttnn::SimpleShape> MorehClipGradNormStep1::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
 
 std::vector<Tensor> MorehClipGradNormStep1::create_output_tensors(const std::vector<Tensor> &) const { return {}; }
 
@@ -105,7 +105,7 @@ void MorehClipGradNormStep2::validate(const std::vector<Tensor> &input_tensors)
     check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm");
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehClipGradNormStep2::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
+std::vector<ttnn::SimpleShape> MorehClipGradNormStep2::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
 
 std::vector<Tensor> MorehClipGradNormStep2::create_output_tensors(const std::vector<Tensor> &) const { return {}; }
 
@@ -146,7 +146,7 @@ void MorehClipGradNormStep3::validate(
     check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped");
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehClipGradNormStep3::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
+std::vector<ttnn::SimpleShape> MorehClipGradNormStep3::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
 
 std::vector<Tensor> MorehClipGradNormStep3::create_output_tensors(const std::vector<Tensor> &) const { return {}; }
 
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp
index c946befe11d..3e84fee79c3 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp
@@ -32,7 +32,7 @@ struct MorehClipGradNormStep1 {
     void validate(
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors,
@@ -49,7 +49,7 @@ struct MorehClipGradNormStep2 {
     float norm_type;
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &) const;
@@ -64,7 +64,7 @@ struct MorehClipGradNormStep3 {
     void validate(
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors,
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp
index dabf8dd64ba..2b38e2f6bf4 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp
@@ -46,13 +46,11 @@ void MorehDot::validate(const std::vector<Tensor>& input_tensors) const {
         "Operands to matmul need to be allocated in buffers on device!");
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehDot::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::SimpleShape> MorehDot::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    auto output_shape = input_tensor.get_legacy_shape();
-    auto padding = output_shape.padding();
-    output_shape[3] = TILE_WIDTH;
-    padding[3] = Padding::PadDimension{0, 31};
-    return {tt::tt_metal::LegacyShape(output_shape, padding)};
+    auto output_shape = input_tensor.get_logical_shape();
+    output_shape[3] = 1;
+    return {output_shape};
 }
 
 std::vector<Tensor> MorehDot::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp
index 0be70fedaa7..288b8b07729 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp
@@ -27,7 +27,7 @@ struct MorehDot {
     const DataType output_dtype;  // TODO: Uplift output_dtype as an option for general dot/bmm
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp
index dfa0fec1c50..7aa118e7a93 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp
@@ -63,7 +63,7 @@ void MorehDotBackward::validate(
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehDotBackward::compute_output_shapes(const std::vector<Tensor>& inputs) const {
+std::vector<ttnn::SimpleShape> MorehDotBackward::compute_output_shapes(const std::vector<Tensor>& inputs) const {
     // Inplace
     return {};
 }
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp
index 6e073dd5723..c00a1b7b96a 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp
@@ -29,7 +29,7 @@ operation::ProgramWithCallbacks moreh_dot_backward_single_core(
 struct MorehDotBackward {
     void validate(
         const std::vector<Tensor> &inputs, const std::vector<std::optional<const Tensor>> &optional_inputs) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &inputs) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &inputs) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &inputs) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &inputs,
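The moreh_dot change above shows the pattern this patch applies throughout: the op now reports the logical width of 1, where the old code hard-coded TILE_WIDTH plus a Padding of {0, 31}; the padded extent is instead derived from the layout. An illustrative round-trip using get_physical_shape, which this patch adds in ttnn/cpp/ttnn/tensor/types.cpp (shape values are hypothetical):

    ttnn::SimpleShape logical({1, 1, 32, 1});                         // dot-product output, logical dims
    auto physical = ttnn::get_physical_shape(logical, Layout::TILE);  // (1, 1, 32, 32): last dim rounded up to TILE_WIDTH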
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp
index bfac7fe8658..7f09ff62e74 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp
@@ -405,39 +405,26 @@ void MorehLayerNorm::validate_with_output_tensors(
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehLayerNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::SimpleShape> MorehLayerNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
     auto input = input_tensors.at(0);
 
     // compute mean_rstd_shape
-    tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape();
-    auto input_shape_without_padding = input_shape.without_padding();
+    auto input_shape = input.get_logical_shape();
     auto input_rank = input_shape.rank();
     auto output_rank = input_rank - normalized_dims;
 
-    std::vector<uint32_t> output_size_vec;
-    auto dimensions_pads = std::vector<Padding::PadDimension>();
+    std::vector<uint32_t> output_shape_vec;
 
     // special case handling
     if (output_rank == 1) {
-        output_size_vec.push_back(32);
-        dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 31});
+        output_shape_vec.push_back(1);
     }
 
     for (uint32_t dim = 0 ; dim < output_rank; dim++) {
-        auto input_shape_without_padding_size = input_shape_without_padding[dim];
-        if (is_hw_dim(dim, output_rank)) {
-            output_size_vec.push_back(round_up_to_mul32(input_shape_without_padding_size));
-
-            auto padding_back = output_size_vec[dim] - input_shape_without_padding_size;
-            dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = padding_back});
-        } else {
-            output_size_vec.push_back(input_shape_without_padding_size);
-            dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 0});
-        }
+        output_shape_vec.push_back(input_shape[dim]);
     }
 
-    const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
-    auto mean_rstd_output_shape = tt::tt_metal::LegacyShape(output_size_vec, padding);
+    ttnn::SimpleShape mean_rstd_output_shape(std::move(output_shape_vec));
 
     return {input_shape, mean_rstd_output_shape, mean_rstd_output_shape};
 }
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp
index 18edae72fd4..de0111eb97e 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp
@@ -34,7 +34,7 @@ struct MorehLayerNorm {
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<const Tensor>> &optional_input_tensors,
         const std::vector<std::optional<Tensor>> &output_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(
         const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp
index 2f09730b4a6..f0ff2ff97f1 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp
@@ -62,10 +62,10 @@ void MorehLayerNormBackwardInputGrad::validate_with_output_tensors(
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehLayerNormBackwardInputGrad::compute_output_shapes(
+std::vector<ttnn::SimpleShape> MorehLayerNormBackwardInputGrad::compute_output_shapes(
     const std::vector<Tensor>& input_tensors) const {
     auto input = input_tensors.at(0);
-    auto input_shape = input.get_legacy_shape();
+    auto input_shape = input.get_logical_shape();
 
     // The shapes of the input and output are always the same.
     return {input_shape};
@@ -131,7 +131,7 @@ void MorehLayerNormBackwardGammaBetaGrad::validate_with_output_tensors(
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> MorehLayerNormBackwardGammaBetaGrad::compute_output_shapes(
+std::vector<ttnn::SimpleShape> MorehLayerNormBackwardGammaBetaGrad::compute_output_shapes(
     const std::vector<Tensor>& input_tensors) const {
     TT_THROW("The compute_output_shapes function in MorehLayerNormBackwardGammaBetaGrad is not implemented.");
     return {};
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp
index 6e46832f6e7..f66c8f1061c 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp
@@ -30,7 +30,7 @@ struct MorehLayerNormBackwardInputGrad {
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<const Tensor>> &optional_input_tensors,
         const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors,
@@ -46,7 +46,7 @@ struct MorehLayerNormBackwardGammaBetaGrad {
     void validate_with_output_tensors(
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<Tensor>>& output_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp
index 7de14906c48..3a491640609 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp
@@ -32,18 +32,16 @@ inline bool is_dot_forward(const Tensor& input, const Tensor& other, bool transp
     return is_1d_tensor(input) && is_1d_tensor(other) && is_same_shape(input, other);
 }
 
-tt::tt_metal::LegacyShape compute_output_shape(
+ttnn::SimpleShape compute_output_shape(
     const tt::tt_metal::LegacyShape& input_shape,
     const tt::tt_metal::LegacyShape& other_shape,
     bool transpose_input,
     bool transpose_other) {
-    const auto& input_shape_wo_padding = input_shape.without_padding();
-    const auto& other_shape_wo_padding = other_shape.without_padding();
+    const auto& logical_input_shape = input_shape.logical_shape();
+    const auto& logical_other_shape = other_shape.logical_shape();
 
-    auto h = (transpose_input) ? (input_shape[-1]) : (input_shape[-2]);
-    auto w = (transpose_other) ? (other_shape[-2]) : (other_shape[-1]);
-    auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]);
-    auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]);
+    auto h = (transpose_input) ? (logical_input_shape[-1]) : (logical_input_shape[-2]);
+    auto w = (transpose_other) ? (logical_other_shape[-2]) : (logical_other_shape[-1]);
 
     std::vector<uint32_t> input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
     std::vector<uint32_t> other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
@@ -72,12 +70,7 @@ tt::tt_metal::LegacyShape compute_output_shape(
     output_dim[output_rank - 2] = h;
     output_dim[output_rank - 1] = w;
 
-    tt::tt_metal::LegacyShape output_shape{output_dim};
-    auto padding = output_shape.padding();
-    // padding for t logmatrix dims
-    padding[output_rank - 2] = Padding::PadDimension{0, h - h_wo_padding};
-    padding[output_rank - 1] = Padding::PadDimension{0, w - w_wo_padding};
-    return {tt::tt_metal::LegacyShape(output_shape, padding)};
+    return {ttnn::SimpleShape(std::move(output_dim))};
 }
 
 }  // namespace
@@ -159,7 +152,7 @@ operation::ProgramWithCallbacks MorehMatmul::create_program(
 }
 
 // Must be provided in the case where an optional output tensor was not provided
-std::vector<tt::tt_metal::LegacyShape> MorehMatmul::compute_output_shapes(
+std::vector<ttnn::SimpleShape> MorehMatmul::compute_output_shapes(
     const std::vector<Tensor>& input_tensors) const {
     return {compute_output_shape(
         input_tensors.at(0).get_legacy_shape(),
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp
index 82ae710c93f..5297ee9add6 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp
@@ -39,7 +39,7 @@ struct MorehMatmul {
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<const Tensor>> &optional_input_tensors,
         const std::vector<std::optional<Tensor>> &output_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(
         const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const;
     operation::ProgramWithCallbacks create_program(
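A worked example of the reworked moreh_matmul shape math (dimensions are illustrative; the batch-dimension handling is in code the hunks above do not modify):

    // With transpose_input = transpose_other = false:
    //   input logical shape (3, 1, 64, 32)  -> h = 64
    //   other logical shape (3, 1, 32, 128) -> w = 128
    // output logical shape: (3, 1, 64, 128).
    // Previously the result also carried Padding entries rounding h and w up to
    // tile multiples; with ttnn::SimpleShape only the logical extents remain.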
diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp
index 5f5efc1b85e..e20faf18fdf 100644
--- a/ttnn/cpp/ttnn/operation.hpp
+++ b/ttnn/cpp/ttnn/operation.hpp
@@ -384,6 +384,7 @@ template <class OutputTensorsT = Tensors>
 struct DeviceOperation final {
     using storage_t = std::array;
     using OutputTensors = OutputTensorsT;
+    using ComputedShapes = std::variant<std::vector<tt::tt_metal::LegacyShape>, std::vector<ttnn::SimpleShape>>;
 
     inline const std::string get_type_name() const { return this->get_type_name_impl_(this->type_erased_storage); }
 
@@ -395,7 +396,7 @@ struct DeviceOperation final {
             this->type_erased_storage, input_tensors, optional_input_tensors, optional_output_tensors);
     }
 
-    inline const std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const Tensors& input_tensors) const {
+    inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const {
         return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors);
     }
 
@@ -544,7 +545,7 @@ struct DeviceOperation final {
             }
         }},
         compute_output_shapes_impl_{
-            [](const storage_t& storage, const Tensors& input_tensors) -> const std::vector<tt::tt_metal::LegacyShape> {
+            [](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes {
                 const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
                 return operation.compute_output_shapes(input_tensors);
             }},
@@ -753,7 +754,7 @@ struct DeviceOperation final {
         const Tensors&,
         const std::vector<std::optional<const Tensor>>&,
         const OptionalTensors&);
-    const std::vector<tt::tt_metal::LegacyShape> (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&);
+    const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&);
     const OutputTensors (*create_output_tensors_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&);
     CacheableProgram<OutputTensors> (*create_program_impl_)(
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
index 026cf6d0ddb..3461d5f1da7 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
@@ -165,10 +165,10 @@ void AllGather::validate(const std::vector<Tensor> &input_tensors) const {
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> AllGather::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
-    auto shape = input_tensors[0].get_legacy_shape();
+std::vector<ttnn::SimpleShape> AllGather::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
+    auto shape = input_tensors[0].get_logical_shape();
     shape[this->dim] *= this->ring_size;
-    return std::vector<tt::tt_metal::LegacyShape>(input_tensors.size(), shape);
+    return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
 }
 
 std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
index 041561bcc87..607bb80af49 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
@@ -132,7 +132,7 @@ struct AllGather {
     const ccl::Topology topology;
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
 };
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
index a573b0ff262..c87d5e35a93 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
@@ -22,13 +22,13 @@ void ReduceScatter::validate(const std::vector<Tensor>& input_tensors) const {
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> ReduceScatter::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    auto shape = input_tensors[0].get_legacy_shape();
+std::vector<ttnn::SimpleShape> ReduceScatter::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+    auto shape = input_tensors[0].get_logical_shape();
     TT_FATAL(
         shape[this->scatter_dim] % this->ring_size == 0,
         "The size of the scatter dimension must be a multiple of the ring size");
     shape[this->scatter_dim] /= this->ring_size;
-    return std::vector<tt::tt_metal::LegacyShape>(input_tensors.size(), shape);
+    return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
 }
 
 std::vector<Tensor> ReduceScatter::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp
index 752a42020a4..996d3078ca0 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp
@@ -25,7 +25,7 @@ struct ReduceScatter {
     const std::optional<size_t> user_defined_num_buffers_per_channel;
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp
index 0a310ac4b75..d2dd0bc14dd 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp
@@ -47,9 +47,9 @@ void EltwiseBinaryBroadcast::validate_with_output_tensors(const std::vector<Tens
-        const std::vector<tt::tt_metal::LegacyShape> output_shape_required = this->compute_output_shapes(input_tensors);
+        const std::vector<ttnn::SimpleShape> output_shape_required = this->compute_output_shapes(input_tensors);
         const auto& out_tensor = output_tensors.at(0).value();
-        TT_FATAL(out_tensor.get_legacy_shape() == output_shape_required.at(0), "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape());
+        TT_FATAL(out_tensor.get_logical_shape() == output_shape_required.at(0), "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape());
     }
     if (this->in_place) {
         TT_FATAL(input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout, "Error");
@@ -109,9 +109,9 @@ void EltwiseBinaryBroadcast::validate_with_output_tensors(const std::vector<Tens
 
-std::vector<tt::tt_metal::LegacyShape> EltwiseBinaryBroadcast::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
+std::vector<ttnn::SimpleShape> EltwiseBinaryBroadcast::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
-    return {input_tensor.get_legacy_shape()};
+    return {input_tensor.get_logical_shape()};
 }
 
diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp
index a7fcb22f395..a2a56b717f3 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp
@@ -29,7 +29,7 @@ struct EltwiseBinaryBroadcast {
     const bool in_place;
 
     void validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
index 472b107a4e6..1e3d958c661 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp
@@ -61,11 +61,11 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) c
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> ConcatDeviceOperation::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
-    tt::tt_metal::LegacyShape shape_out = input_tensors[0].get_legacy_shape();
+std::vector<ttnn::SimpleShape> ConcatDeviceOperation::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
+    ttnn::SimpleShape shape_out = input_tensors[0].get_logical_shape();
     shape_out[this->dim] = 0;
     for (const Tensor &in_ref : input_tensors) {
-        tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape();
+        ttnn::SimpleShape curr_shape = in_ref.get_logical_shape();
         shape_out[this->dim] += curr_shape[this->dim];
     }
     return {shape_out};
diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp
index 0e5a35500a1..86fe135637b 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp
@@ -15,7 +15,7 @@ struct ConcatDeviceOperation {
     uint32_t dim;
     const MemoryConfig output_mem_config;
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
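For the concat port above, the shape arithmetic is now pure logical-dimension math; an illustrative case (values hypothetical):

    // Concatenating logical shapes (2, 3, 32, 32) and (2, 5, 32, 32) along dim = 1:
    // shape_out starts as (2, 0, 32, 32) and accumulates to (2, 8, 32, 32).
    // Under LegacyShape the same loop also had to keep padded extents consistent.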
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
index 7ec8010cf47..c6be2478fff 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
@@ -54,15 +54,15 @@ void AllGatherMatmul::validate(const std::vector<Tensor> &input_tensors, const s
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> AllGatherMatmul::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
+std::vector<ttnn::SimpleShape> AllGatherMatmul::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
 
     // All Gather shape
-    tt::tt_metal::LegacyShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0];
-    tt::tt_metal::LegacyShape datacopy_output_shape = all_gather_output_shape;
+    ttnn::SimpleShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0];
+    ttnn::SimpleShape datacopy_output_shape = all_gather_output_shape;
 
     // Matmul shape
-    tt::tt_metal::LegacyShape matmul_output_shapes = this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0];
+    ttnn::SimpleShape matmul_output_shapes = this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0];
 
     return {all_gather_output_shape, matmul_output_shapes, datacopy_output_shape};
 }
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp
index 3d57614fefe..6dc88b1086d 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp
@@ -42,7 +42,7 @@ struct AllGatherMatmul {
     /* General */
     void validate(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index 4d487336e79..2a6af0f88e8 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1298,29 +1298,24 @@ void Matmul::validate(
         chosen_program_config);
 }
 
-std::vector<tt::tt_metal::LegacyShape> Matmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    const tt::tt_metal::LegacyShape& input_shape_a = input_tensors.at(0).get_legacy_shape();
-    const tt::tt_metal::LegacyShape& input_shape_b = input_tensors.at(1).get_legacy_shape();
+std::vector<ttnn::SimpleShape> Matmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+    ttnn::SimpleShape input_shape_a = input_tensors.at(0).get_logical_shape();
+    ttnn::SimpleShape input_shape_b = input_tensors.at(1).get_logical_shape();
     const uint32_t a_rank = input_shape_a.rank();
     const uint32_t b_rank = input_shape_b.rank();
     const uint32_t out_rank = std::max(a_rank, b_rank);
     const uint32_t rank_difference = out_rank - a_rank;
-    tt::tt_metal::LegacyShape output_shape = (b_rank > a_rank) ? input_shape_b : input_shape_a;
-    auto dimensions_pads = std::vector<Padding::PadDimension>();
+    ttnn::SimpleShape output_shape = (b_rank > a_rank) ? input_shape_b : input_shape_a;
 
     for (auto index = 0; index < rank_difference; index++) {
         TT_FATAL(input_shape_b[index] == 1, "When in1 rank greater than in0 rank front dimensions need to be 1");
         output_shape[index] = input_shape_b[index];
-        dimensions_pads.push_back(input_shape_b.padding()[index]);
     }
     for (auto index = 0; index < a_rank - 1; index++) {
         output_shape[rank_difference + index] = input_shape_a[index];
-        dimensions_pads.push_back(input_shape_a.padding()[index]);
     }
     output_shape[-1] = input_shape_b[-1];
-    dimensions_pads.push_back(input_shape_b.padding()[b_rank - 1]);
-    const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
-    return {tt::tt_metal::LegacyShape(output_shape, padding)};
+    return {std::move(output_shape)};
 }
 
 std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
index b08678d871e..32eb87cd13a 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
@@ -169,7 +169,7 @@ struct Matmul {
     void validate(
         const std::vector<Tensor> &input_tensors,
         const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<tt::tt_metal::LegacyShape> compute_output_shapes_dram_sharded(
         const std::vector<Tensor> &input_tensors, uint32_t N_unpadded) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp
index 151f50ae67f..ac997b24f68 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp
@@ -50,14 +50,12 @@ void MorehDotOperation::validate_on_program_cache_hit(
 MorehDotOperation::shape_return_value_t MorehDotOperation::compute_output_shapes(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     if (tensor_args.output.has_value()) {
-        return tensor_args.output.value().get_shape();
+        return tensor_args.output.value().get_logical_shape();
     }
     const auto& input = tensor_args.input_a;
-    auto output_shape = input.get_shape().value;
-    auto padding = output_shape.padding();
-    output_shape[3] = tt::constants::TILE_WIDTH;
-    padding[3] = Padding::PadDimension{0, 31};
-    return ttnn::Shape{tt::tt_metal::LegacyShape(output_shape, padding)};
+    auto output_shape = input.get_logical_shape();
+    output_shape[3] = 1;
+    return ttnn::SimpleShape{std::move(output_shape)};
 }
 
 MorehDotOperation::tensor_return_value_t MorehDotOperation::create_output_tensors(
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp
index 7c02317988e..727b282a362 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp
@@ -23,7 +23,7 @@ struct MorehDotOperation {
         const std::optional<Tensor>& output;
     };
 
-    using shape_return_value_t = Shape;
+    using shape_return_value_t = SimpleShape;
     using tensor_return_value_t = Tensor;
 
     struct SingleCore {
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp
index d7185780040..34ae223deb8 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp
@@ -26,7 +26,7 @@ struct MorehDotBackwardOperation {
         const std::vector<std::optional<Tensor>> output_tensors;
     };
 
-    using shape_return_value_t = std::vector<std::optional<Shape>>;
+    using shape_return_value_t = std::vector<std::optional<SimpleShape>>;
     using tensor_return_value_t = std::vector<std::optional<Tensor>>;
 
     struct SingleCore {
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp
index a4a6a19b365..945cfa6aafa 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp
@@ -84,20 +84,16 @@ MorehGroupNormOperation::shape_return_value_t MorehGroupNormOperation::compute_o
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     using namespace tt::constants;
     // mean, rstd (1, 1, N, num_groups)
-    const auto output_shape = tensor_args.input.get_shape();
-    const auto N = output_shape.value[0];
+    const auto output_shape = tensor_args.input.get_logical_shape();
+    const auto N = output_shape[0];
     const auto num_groups = operation_attributes.num_groups;
-    const std::vector<uint32_t> mean_rstd_origin_shape{
+    std::vector<uint32_t> mean_rstd_origin_shape{
         1,
         1,
-        TILE_HEIGHT * ((N + TILE_HEIGHT - 1) / TILE_HEIGHT),
-        TILE_WIDTH * ((num_groups + TILE_WIDTH - 1) / TILE_WIDTH)};
+        N,
+        num_groups};
 
-    auto mean_rstd_padding = output_shape.value.padding();
-    mean_rstd_padding[2] = Padding::PadDimension{0, TILE_HEIGHT - (N % TILE_HEIGHT)};
-    mean_rstd_padding[3] = Padding::PadDimension{0, TILE_WIDTH - (num_groups % TILE_WIDTH)};
-
-    Shape mean_rstd_shape = Shape(tt::tt_metal::LegacyShape(mean_rstd_origin_shape, mean_rstd_padding));
+    SimpleShape mean_rstd_shape(std::move(mean_rstd_origin_shape));
     return {output_shape, mean_rstd_shape, mean_rstd_shape};
 }
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp
index 338cd28123a..480aac7cf01 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp
@@ -27,7 +27,7 @@ struct MorehGroupNormOperation {
         const std::optional<Tensor> rstd;
     };
 
-    using shape_return_value_t = std::vector<std::optional<Shape>>;
+    using shape_return_value_t = std::vector<std::optional<SimpleShape>>;
     using tensor_return_value_t = std::vector<std::optional<Tensor>>;
 
     struct MorehGroupNormFactory {
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp
index ea9e217f356..cafcb99c1be 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp
@@ -44,7 +44,7 @@ std::vector<Tensor> Prod::create_output_tensors(const std::vector<Tensor>& input
     return {};
 }
 
-std::vector<tt::tt_metal::LegacyShape> Prod::compute_output_shapes(const std::vector<Tensor>& inputs) const {
+std::vector<ttnn::SimpleShape> Prod::compute_output_shapes(const std::vector<Tensor>& inputs) const {
     // Inplace
     return {};
 
diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp
index 7d92526127b..5552f120a4d 100644
--- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp
+++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp
@@ -23,7 +23,7 @@ using namespace tt_metal;
 struct Prod {
     int64_t dim;
     void validate(const std::vector<Tensor> &inputs) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &inputs) const;
+    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &inputs) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &inputs) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) const;
diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp
index cfa2be30c4e..78366e52fef 100644
--- a/ttnn/cpp/ttnn/run_operation.cpp
+++ b/ttnn/cpp/ttnn/run_operation.cpp
@@ -302,6 +302,21 @@ template OptionalTensors run_without_autoformat(
     const OptionalTensors& optional_output_tensors,
     uint8_t cq_id);
 
+std::vector<tt::tt_metal::LegacyShape> extract_legacy_shapes(
+    const std::variant<std::vector<tt::tt_metal::LegacyShape>, std::vector<ttnn::SimpleShape>>&& shapes, const std::function<Layout(size_t)>& layout_provider) {
+    if (std::holds_alternative<std::vector<tt::tt_metal::LegacyShape>>(shapes)) {
+        return std::get<std::vector<tt::tt_metal::LegacyShape>>(std::move(shapes));
+    }
+    const auto& simple_shapes = std::get<std::vector<ttnn::SimpleShape>>(shapes);
+    std::vector<tt::tt_metal::LegacyShape> legacy_shapes;
+    legacy_shapes.reserve(simple_shapes.size());
+    for (size_t idx = 0; idx < simple_shapes.size(); idx++) {
+        auto layout = layout_provider(idx);
+        legacy_shapes.emplace_back(simple_shapes[idx].as_vector(), get_physical_shape(simple_shapes[idx], layout).as_vector());
+    }
+    return legacy_shapes;
+}
+
 // To be deprecated/removed in favor of new implementation where ops specifically request how to format inputs/outputss
 Tensors run_with_autoformat(
     DeviceOperation<Tensors>&& operation,
@@ -314,7 +329,9 @@ Tensors run_with_autoformat(
     using ttnn::operations::experimental::auto_format::AutoFormat;
     ZoneScoped;
     Device* device = detail::get_device(input_tensors, optional_input_tensors);
-    auto output_shapes = operation.compute_output_shapes(input_tensors);
+    auto output_shapes = extract_legacy_shapes(operation.compute_output_shapes(input_tensors), [](size_t) {
+        return Layout::TILE;
+    });
 
     Tensors formatted_input_tensors;
     formatted_input_tensors.reserve(input_tensors.size());
@@ -372,7 +389,9 @@ Tensors run_with_autoformat(
     using ttnn::operations::experimental::auto_format::AutoFormat;
     ZoneScoped;
     Device* device = detail::get_device(input_tensors, optional_input_tensors);
-    auto output_shapes = operation.compute_output_shapes(input_tensors);
+    auto output_shapes = extract_legacy_shapes(operation.compute_output_shapes(input_tensors), [&](size_t idx) {
+        return output_layouts[idx];
+    });
 
     TT_ASSERT(input_tensors.size() == input_formatting.size());
     TT_ASSERT(optional_input_tensors.size() == optional_input_formatting.size());
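extract_legacy_shapes is the compatibility shim for the autoformat path: legacy results pass through unchanged, while SimpleShape results are paired with a physical shape computed from the per-output layout. An illustrative call (shape values hypothetical; relies on the vector converting implicitly to the variant parameter):

    auto legacy = extract_legacy_shapes(
        std::vector<ttnn::SimpleShape>{ttnn::SimpleShape({1, 1, 30, 100})},
        [](size_t) { return Layout::TILE; });
    // legacy[0] carries logical (1, 1, 30, 100) alongside physical (1, 1, 32, 128).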
diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index 4f0fe5d95e2..642b2acec64 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -696,6 +696,11 @@ Tensor create_device_tensor(
     }
 }
 
+Tensor create_device_tensor(
+    const ttnn::SimpleShape& logical_shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional<Tile>& tile) {
+    return create_device_tensor(logical_shape, get_physical_shape(logical_shape, layout, tile), data_type, layout, device, memory_config, tile);
+}
+
 Tensor create_device_tensor(
     const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional<Tile>& tile) {
     return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile);
diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp
index e23832be836..4f201073097 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -293,6 +293,14 @@ struct Tensor {
     }
 };
 
+Tensor create_device_tensor(
+    const ttnn::SimpleShape &logical_shape,
+    DataType dtype,
+    Layout layout,
+    Device *device,
+    const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
+    const std::optional<Tile>& tile = std::nullopt);
+
 Tensor create_device_tensor(
     const ttnn::SimpleShape &logical_shape,
     const ttnn::SimpleShape &padded_shape,
diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp
index ebf3001f5b8..7814c9d2a18 100644
--- a/ttnn/cpp/ttnn/tensor/types.cpp
+++ b/ttnn/cpp/ttnn/tensor/types.cpp
@@ -5,6 +5,31 @@
 #include
 #include "ttnn/tensor/types.hpp"
 
+namespace ttnn {
+
+SimpleShape get_physical_shape(const SimpleShape& logical_shape, Layout layout, const std::optional<Tile>& tile) {
+    SimpleShape physical_shape = logical_shape;
+    if (layout == Layout::TILE) {
+        auto tile_height = tt::constants::TILE_HEIGHT;
+        auto tile_width = tt::constants::TILE_WIDTH;
+        if (tile.has_value()) {
+            auto tile_shape = tile.value().get_tile_shape();
+            tile_height = tile_shape[0];
+            tile_width = tile_shape[1];
+        }
+        auto rank = physical_shape.rank();
+        if (rank >= 1) {
+            physical_shape[rank - 1] = (physical_shape[rank - 1] + tile_width - 1) / tile_width * tile_width;
+            if (rank >= 2) {
+                physical_shape[rank - 2] = (physical_shape[rank - 2] + tile_height - 1) / tile_height * tile_height;
+            }
+        }
+    }
+    return physical_shape;
+}
+
+}
+
 namespace tt {
 
 namespace tt_metal {
diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp
index baffe41d56c..2004f4fb19d 100644
--- a/ttnn/cpp/ttnn/tensor/types.hpp
+++ b/ttnn/cpp/ttnn/tensor/types.hpp
@@ -69,6 +69,8 @@ class SimpleShape {
     std::vector<uint32_t> value;
 };
 
+SimpleShape get_physical_shape(const SimpleShape& logical_shape, Layout layout, const std::optional<Tile>& tile = std::nullopt);
+
 }  // namespace ttnn
 
 inline std::ostream &operator<<(std::ostream &os, const ttnn::SimpleShape &shape) {
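Taken together, the tensor.cpp/types.cpp additions let callers create a device tensor from a logical shape alone, with padding derived from layout and tile size. Illustrative behavior of the new helper (values hypothetical):

    ttnn::SimpleShape logical({2, 3, 30, 100});
    auto tiled = ttnn::get_physical_shape(logical, Layout::TILE);       // (2, 3, 32, 128) with the default 32x32 tile
    auto row_major = ttnn::get_physical_shape(logical, Layout::ROW_MAJOR);  // unchanged: (2, 3, 30, 100)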