Skip to content

Commit

Permalink
#13127: Switch create_device_tensor to automatically figure out paddi…
Browse files Browse the repository at this point in the history
…ng for tile layout and if only height and width are padded to nearest tile

- Add warning for user when trying to create row_major layout device tensor with any padding
- Add warning for user when trying to create tile layout device tensor with padding along non-height or width
- Add warning for user when trying to create tile layout device tensor with padding along height or width that exceeds the nearest tile
  • Loading branch information
TT-BrianLiu committed Oct 11, 2024
1 parent ab62c47 commit d191fc0
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion ttnn/cpp/ttnn/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,27 @@ Tensor create_device_tensor(

Tensor create_device_tensor(
const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional<Tile>& tile) {
return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile);

if (layout == Layout::ROW_MAJOR) {
if (shape.has_tile_padding()) {
tt::log_warning("ttnn::Shape {} represents a row_major tensor with padding! Falling back to pass logical and padded shape to create_device_tensor.", shape);
return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile);
}
} else {
for (size_t dim = 0; dim < shape.rank(); dim++) {
if (dim < shape.rank() - 2) {
if (shape.has_tile_padding(dim)) {
tt::log_warning("ttnn::Shape {} has padding along dims that are not height and width! Falling back to pass logical and padded shape to create_device_tensor.", shape);
return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile);
}
} else if (shape.padded_shape()[dim] - shape.logical_shape()[dim] >= ttnn::TILE_SIZE) {
// NOTE: This also covers the case where logical dim 0 is padded up to 32
tt::log_warning("ttnn::Shape {} has padding along height or width that exceeds nearest tile! Falling back to pass logical and padded shape to create_device_tensor.", shape);
return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile);
}
}
}
return create_device_tensor(shape.logical_shape(), data_type, layout, device, memory_config, tile);
}

namespace detail {
Expand Down

0 comments on commit d191fc0

Please sign in to comment.