Commit 4d5afb1
#0: Fix sweep
sankarmanoj-tt committed Feb 25, 2025
1 parent 361792d commit 4d5afb1
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -694,8 +694,6 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t in_channels_padded = tt::round_up(in_channels, input_num_cores_channels * input_channels_alignment);
     uint32_t out_channel_padding = out_channels_padded - out_channels;
 
-    ttnn::Shape weights_channels_padded_shape(
-        std::array<uint32_t, 4>({out_channels_padded, in_channels_padded, window_h, window_w}));
     if (weights_bias_dtype == DataType::BFLOAT8_B) {
         TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32);
         if (bias_tensor.has_value()) {
@@ -711,9 +709,14 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
 
     // Block sharding re-orders the weights by dividing the input_channels along number of in_channel_cores.
     if (input_parallel_config.shard_scheme == TensorMemoryLayout::BLOCK_SHARDED) {
+        weight_tensor_ = ttnn::permute(weight_tensor_, ttnn::SmallVector<int64_t>({2, 3, 1, 0}));
+
+        ttnn::Shape weights_channels_padded_shape(
+            std::array<uint32_t, 4>({window_h, window_w, out_channels_padded, in_channels_padded}));
+
         weight_tensor_ = ttnn::pad(
             weight_tensor_,
-            weights_channels_padded_shape.to_array_4D(),
+            tt::tt_metal::Array4D({window_h, window_w, in_channels_padded, out_channels_padded}),
             tt::tt_metal::Array4D({0, 0, 0, 0}),
             0.0f,
             true,
@@ -745,12 +748,12 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
         weight_tensor_ = ttnn::reshape(
             weight_tensor_,
             ttnn::Shape(
-                {output_num_cores_channels, out_channels_per_core, in_channels_padded * window_h, window_w}));
+                {in_channels_padded * window_h, window_w, output_num_cores_channels, out_channels_per_core}));
 
         weight_tensor_ = ttnn::pad(
             weight_tensor_,
             tt::tt_metal::Array4D(
-                {output_num_cores_channels, rounded_weight_block_width, in_channels_padded * window_h, window_w}),
+                {in_channels_padded * window_h, window_w, output_num_cores_channels, rounded_weight_block_width}),
             tt::tt_metal::Array4D({0, 0, 0, 0}),
             0,
             true,
@@ -759,15 +762,13 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
         weight_tensor_ = ttnn::reshape(
             weight_tensor_,
             ttnn::Shape(
-                {final_out_channels_padded, input_num_cores_channels, in_channels_per_core, window_h, window_w}));
-
-        weight_tensor_ = ttnn::permute(weight_tensor_, ttnn::SmallVector<int64_t>({1, 3, 4, 2, 0}));
-        // Shape is now {input_num_cores_channels, window_h, window_w, in_channels_per_core, out_channels_padded}
+                {window_h, window_w, input_num_cores_channels, in_channels_per_core, final_out_channels_padded}));
 
+        weight_tensor_ = ttnn::permute(weight_tensor_, ttnn::SmallVector<int64_t>({2, 0, 1, 3, 4}));
         weight_tensor_ = ttnn::reshape(
             weight_tensor_,
             ttnn::Shape(
-                {1, input_num_cores_channels, in_channels_per_core * window_h * window_w, final_out_channels_padded}));
+                {1, input_num_cores_channels, window_h * window_w * in_channels_per_core, final_out_channels_padded}));
         weight_tensor_ = ttnn::pad(
             weight_tensor_,
             tt::tt_metal::Array4D(
@@ -776,12 +777,16 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
             0,
             true,
             std::nullopt);
 
         weight_tensor_ = ttnn::reshape(
             weight_tensor_,
             ttnn::Shape({1, 1, rounded_weight_block_height * input_num_cores_channels, final_out_channels_padded}));
     } else {
         weight_tensor_ = ttnn::permute(weight_tensor_, ttnn::SmallVector<int64_t>({2, 3, 1, 0}));
+
+        ttnn::Shape weights_channels_padded_shape(
+            std::array<uint32_t, 4>({window_h, window_w, out_channels_padded, in_channels_padded}));
+
         weight_tensor_ = ttnn::pad(
             weight_tensor_,
             tt::tt_metal::Array4D({window_h, window_w, in_channels_padded, out_channels_padded}),
@@ -790,7 +795,6 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
             true,
             std::nullopt);
 
-        // Shape is now {1, window_h, window_w, in_channels_padded, out_channels_padded}
         auto weight_block_h_datums = weight_block_h_ntiles * constants::TILE_HEIGHT;
         if ((weight_block_h_datums > (window_w * in_channels_padded)) &&
             (input_parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED)) {

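For anyone tracing the new layout: the block-sharded path now permutes the incoming {out_channels, in_channels, window_h, window_w} weights to {window_h, window_w, in_channels, out_channels} up front, splits out_channels across the output cores, then exposes the per-core input-channel split with a 5-D reshape followed by the {2, 0, 1, 3, 4} permute. Below is a minimal standalone sketch of that shape arithmetic. It is not part of the commit: it uses plain vectors instead of ttnn tensors, example dimensions rather than real ones, and omits the intermediate padding steps (rounded_weight_block_width and friends).

// shape_trace.cpp -- trace the weight-shape sequence from the block-sharded
// branch above. All dimensions are example values, not taken from the commit.
#include <cstdint>
#include <iostream>
#include <vector>

// Effect of a permute on a shape: output dim i takes input dim dims[i],
// matching the usual permute semantics assumed for ttnn::permute here.
std::vector<uint32_t> permute_shape(const std::vector<uint32_t>& in,
                                    const std::vector<int64_t>& dims) {
    std::vector<uint32_t> out(dims.size());
    for (size_t i = 0; i < dims.size(); ++i) {
        out[i] = in[static_cast<size_t>(dims[i])];
    }
    return out;
}

void print_shape(const char* label, const std::vector<uint32_t>& s) {
    std::cout << label << " {";
    for (size_t i = 0; i < s.size(); ++i) {
        std::cout << s[i] << (i + 1 < s.size() ? ", " : "");
    }
    std::cout << "}\n";
}

int main() {
    // Example values: 64 output channels on 2 cores, 32 input channels on 4
    // cores, 3x3 window. Channels are assumed already padded to these counts.
    const uint32_t out_channels_padded = 64, in_channels_padded = 32;
    const uint32_t window_h = 3, window_w = 3;
    const uint32_t output_num_cores_channels = 2, input_num_cores_channels = 4;
    const uint32_t out_channels_per_core = out_channels_padded / output_num_cores_channels;
    const uint32_t in_channels_per_core = in_channels_padded / input_num_cores_channels;

    // Weights arrive as {out_channels, in_channels, window_h, window_w}.
    std::vector<uint32_t> shape = {out_channels_padded, in_channels_padded, window_h, window_w};
    print_shape("initial:             ", shape);

    // permute {2, 3, 1, 0} -> {window_h, window_w, in_channels, out_channels}.
    shape = permute_shape(shape, {2, 3, 1, 0});
    print_shape("after permute:       ", shape);

    // Reshape splitting out_channels across the output cores.
    shape = {in_channels_padded * window_h, window_w, output_num_cores_channels, out_channels_per_core};
    print_shape("out-channel split:   ", shape);

    // 5-D reshape exposing the per-core input-channel split. The diff uses
    // final_out_channels_padded here; this sketch assumes it equals
    // out_channels_padded (no extra rounding).
    shape = {window_h, window_w, input_num_cores_channels, in_channels_per_core, out_channels_padded};
    print_shape("5-D in-channel view: ", shape);

    // permute {2, 0, 1, 3, 4} brings the core dimension to the front:
    // {input_num_cores_channels, window_h, window_w, in_channels_per_core, out}.
    shape = permute_shape(shape, {2, 0, 1, 3, 4});
    print_shape("after 5-D permute:   ", shape);

    // Final flatten to {1, cores, h * w * in_per_core, out}, as in the diff.
    shape = {1, input_num_cores_channels, window_h * window_w * in_channels_per_core, out_channels_padded};
    print_shape("final:               ", shape);
    return 0;
}

Compiling and running this prints each intermediate shape; the element count (64 * 32 * 3 * 3) stays constant across every step, which is the invariant the reshapes in the diff rely on.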