Skip to content

Commit

Permalink
#13127: Allow compute_output_shapes to use SimpleShape instead of LegacyShape, port some ops to SimpleShape (#13645)
Browse files Browse the repository at this point in the history

* #13127: Prototype of moving some operation to ttnn::SimpleShape

* #13127: Port more ops

* #13127: Infra support for SimpleShape

* #13127: Revert convolution changes

* #13127: Refactor to remove code duplication

* #13127: Fix rebase issues

* #13127: Extract extract_legacy_shape function, make get_physical_shape pure function
  • Loading branch information
sminakov-tt authored Oct 9, 2024
1 parent df995f1 commit fa69b0b
Showing 39 changed files with 153 additions and 119 deletions.
8 changes: 4 additions & 4 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ New Device Operation
struct <NewOperation> {
void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};
@@ -48,7 +48,7 @@ New Device Operation with a member
int some_member
void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};
@@ -61,7 +61,7 @@ New Device Operation with Optional Input Tensors
struct <NewOperation> {
void validate(const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor>& input_tensors,
@@ -80,7 +80,7 @@ and create_output_tensors with the additional parameter for the output_tensors.
struct <NewOperation> {
void validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<std::optional<Tensor>> create_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
operation::ProgramWithOptionalOutputTensors create_program(const std::vector<Tensor>& input_tensors, std::vector<std::optional<Tensor>> &output_tensors) const;
9 changes: 8 additions & 1 deletion tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp
Original file line number Diff line number Diff line change
@@ -27,7 +27,14 @@ std::vector<Tensor> run_operation(
const operation::OptionalTensors& optional_output_tensors = {}) {
static_assert(operation::detail::is_device_operation<OpConfig>(), "ttnn::run_operation can only dispatch Device Operations!");
// Create output tensor vector by examining the number of output shapes created by the device operation
std::vector<Tensor> outputs(operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors).size());
auto output_shapes = operation::DeviceOperation<operation::Tensors>(devop).compute_output_shapes(input_tensors);
size_t output_shapes_size = 0;
if (std::holds_alternative<std::vector<ttnn::SimpleShape>>(output_shapes)) {
output_shapes_size = std::get<std::vector<ttnn::SimpleShape>>(output_shapes).size();
} else {
output_shapes_size = std::get<std::vector<tt::tt_metal::LegacyShape>>(output_shapes).size();
}
std::vector<Tensor> outputs(output_shapes_size);
// Populate the workers of the output tensors, based on the input tensors. This is needed for the async engine.
for (int i = 0; i < outputs.size(); i++) {
outputs[i] = Tensor(operation::get_workers_for_op_output(std::move(input_tensors), std::move(optional_input_tensors)));
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ void MorehClipGradNormStep1::validate(
check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum");
};

std::vector<tt::tt_metal::LegacyShape> MorehClipGradNormStep1::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
std::vector<ttnn::SimpleShape> MorehClipGradNormStep1::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }

std::vector<Tensor> MorehClipGradNormStep1::create_output_tensors(const std::vector<Tensor> &) const { return {}; }

@@ -105,7 +105,7 @@ void MorehClipGradNormStep2::validate(const std::vector<Tensor> &input_tensors)
check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm");
}

std::vector<tt::tt_metal::LegacyShape> MorehClipGradNormStep2::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
std::vector<ttnn::SimpleShape> MorehClipGradNormStep2::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }

std::vector<Tensor> MorehClipGradNormStep2::create_output_tensors(const std::vector<Tensor> &) const { return {}; }

@@ -146,7 +146,7 @@ void MorehClipGradNormStep3::validate(
check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped");
}

std::vector<tt::tt_metal::LegacyShape> MorehClipGradNormStep3::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }
std::vector<ttnn::SimpleShape> MorehClipGradNormStep3::compute_output_shapes(const std::vector<Tensor> &) const { return {}; }

std::vector<Tensor> MorehClipGradNormStep3::create_output_tensors(const std::vector<Tensor> &) const { return {}; }

Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ struct MorehClipGradNormStep1 {
void validate(
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors,
@@ -49,7 +49,7 @@ struct MorehClipGradNormStep2 {
float norm_type;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &) const;
@@ -64,7 +64,7 @@ struct MorehClipGradNormStep3 {
void validate(
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors,
Original file line number Diff line number Diff line change
@@ -46,13 +46,11 @@ void MorehDot::validate(const std::vector<Tensor>& input_tensors) const {
"Operands to matmul need to be allocated in buffers on device!");
}

std::vector<tt::tt_metal::LegacyShape> MorehDot::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
std::vector<ttnn::SimpleShape> MorehDot::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
auto output_shape = input_tensor.get_legacy_shape();
auto padding = output_shape.padding();
output_shape[3] = TILE_WIDTH;
padding[3] = Padding::PadDimension{0, 31};
return {tt::tt_metal::LegacyShape(output_shape, padding)};
auto output_shape = input_tensor.get_logical_shape();
output_shape[3] = 1;
return {output_shape};
}

std::vector<Tensor> MorehDot::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ struct MorehDot {
const DataType output_dtype; // TODO: Uplift output_dtype as an option for general dot/bmm

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
Original file line number Diff line number Diff line change
@@ -63,7 +63,7 @@ void MorehDotBackward::validate(
}
}

std::vector<tt::tt_metal::LegacyShape> MorehDotBackward::compute_output_shapes(const std::vector<Tensor>& inputs) const {
std::vector<ttnn::SimpleShape> MorehDotBackward::compute_output_shapes(const std::vector<Tensor>& inputs) const {
// Inplace
return {};
}
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ operation::ProgramWithCallbacks moreh_dot_backward_single_core(
struct MorehDotBackward {
void validate(
const std::vector<Tensor> &inputs, const std::vector<std::optional<const Tensor>> &optional_inputs) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &inputs) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &inputs) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &inputs) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &inputs,
Original file line number Diff line number Diff line change
@@ -405,39 +405,26 @@ void MorehLayerNorm::validate_with_output_tensors(
}
}

std::vector<tt::tt_metal::LegacyShape> MorehLayerNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
std::vector<ttnn::SimpleShape> MorehLayerNorm::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
auto input = input_tensors.at(0);

// compute mean_rstd_shape
tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape();
auto input_shape_without_padding = input_shape.without_padding();
auto input_shape = input.get_logical_shape();
auto input_rank = input_shape.rank();
auto output_rank = input_rank - normalized_dims;

std::vector<uint32_t> output_size_vec;
auto dimensions_pads = std::vector<Padding::PadDimension>();
std::vector<uint32_t> output_shape_vec;

// special case handling
if (output_rank == 1) {
output_size_vec.push_back(32);
dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 31});
output_shape_vec.push_back(1);
}

for (uint32_t dim = 0 ; dim < output_rank; dim++) {
auto input_shape_without_padding_size = input_shape_without_padding[dim];
if (is_hw_dim(dim, output_rank)) {
output_size_vec.push_back(round_up_to_mul32(input_shape_without_padding_size));

auto padding_back = output_size_vec[dim] - input_shape_without_padding_size;
dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = padding_back});
} else {
output_size_vec.push_back(input_shape_without_padding_size);
dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 0});
}
output_shape_vec.push_back(input_shape[dim]);
}

const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
auto mean_rstd_output_shape = tt::tt_metal::LegacyShape(output_size_vec, padding);
ttnn::SimpleShape mean_rstd_output_shape(std::move(output_shape_vec));

return {input_shape, mean_rstd_output_shape, mean_rstd_output_shape};
}
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ struct MorehLayerNorm {
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors,
const std::vector<std::optional<Tensor>> &output_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
operation::ProgramWithCallbacks create_program(
Original file line number Diff line number Diff line change
@@ -62,10 +62,10 @@ void MorehLayerNormBackwardInputGrad::validate_with_output_tensors(
}
}

std::vector<tt::tt_metal::LegacyShape> MorehLayerNormBackwardInputGrad::compute_output_shapes(
std::vector<ttnn::SimpleShape> MorehLayerNormBackwardInputGrad::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
auto input = input_tensors.at(0);
auto input_shape = input.get_legacy_shape();
auto input_shape = input.get_logical_shape();

// The shapes of the input and output are always the same.
return {input_shape};
@@ -131,7 +131,7 @@ void MorehLayerNormBackwardGammaBetaGrad::validate_with_output_tensors(
}
}

std::vector<tt::tt_metal::LegacyShape> MorehLayerNormBackwardGammaBetaGrad::compute_output_shapes(
std::vector<ttnn::SimpleShape> MorehLayerNormBackwardGammaBetaGrad::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
TT_THROW("The compute_output_shapes function in MorehLayerNormBackwardGammaBetaGrad is not implemented.");
return {};
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ struct MorehLayerNormBackwardInputGrad {
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors,
const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors,
@@ -46,7 +46,7 @@ struct MorehLayerNormBackwardGammaBetaGrad {
void validate_with_output_tensors(
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
Original file line number Diff line number Diff line change
@@ -32,18 +32,16 @@ inline bool is_dot_forward(const Tensor& input, const Tensor& other, bool transp
return is_1d_tensor(input) && is_1d_tensor(other) && is_same_shape(input, other);
}

tt::tt_metal::LegacyShape compute_output_shape(
ttnn::SimpleShape compute_output_shape(
const tt::tt_metal::LegacyShape& input_shape,
const tt::tt_metal::LegacyShape& other_shape,
bool transpose_input,
bool transpose_other) {
const auto& input_shape_wo_padding = input_shape.without_padding();
const auto& other_shape_wo_padding = other_shape.without_padding();
const auto& logical_input_shape = input_shape.logical_shape();
const auto& logical_other_shape = other_shape.logical_shape();

auto h = (transpose_input) ? (input_shape[-1]) : (input_shape[-2]);
auto w = (transpose_other) ? (other_shape[-2]) : (other_shape[-1]);
auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]);
auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]);
auto h = (transpose_input) ? (logical_input_shape[-1]) : (logical_input_shape[-2]);
auto w = (transpose_other) ? (logical_other_shape[-2]) : (logical_other_shape[-1]);

std::vector<uint32_t> input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
std::vector<uint32_t> other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
@@ -72,12 +70,7 @@ tt::tt_metal::LegacyShape compute_output_shape(
output_dim[output_rank - 2] = h;
output_dim[output_rank - 1] = w;

tt::tt_metal::LegacyShape output_shape{output_dim};
auto padding = output_shape.padding();
// padding for t logmatrix dims
padding[output_rank - 2] = Padding::PadDimension{0, h - h_wo_padding};
padding[output_rank - 1] = Padding::PadDimension{0, w - w_wo_padding};
return {tt::tt_metal::LegacyShape(output_shape, padding)};
return {ttnn::SimpleShape(std::move(output_dim))};
}

} // namespace
@@ -159,7 +152,7 @@ operation::ProgramWithCallbacks MorehMatmul::create_program(
}

// Must be provided in the case where an optional output tensor was not provided
std::vector<tt::tt_metal::LegacyShape> MorehMatmul::compute_output_shapes(
std::vector<ttnn::SimpleShape> MorehMatmul::compute_output_shapes(
const std::vector<Tensor>& input_tensors) const {
return {compute_output_shape(
input_tensors.at(0).get_legacy_shape(),
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ struct MorehMatmul {
const std::vector<Tensor> &input_tensors,
const std::vector<std::optional<const Tensor>> &optional_input_tensors,
const std::vector<std::optional<Tensor>> &output_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>> &output_tensors) const;
operation::ProgramWithCallbacks create_program(
Loading

0 comments on commit fa69b0b

Please sign in to comment.