Skip to content

Commit

Permalink
Enhance functionality of max_pool2d op (#2176)
Browse files Browse the repository at this point in the history
### Ticket
#2171

### Problem description
TTIR to TTIR decomposition of ttir.pooling to ttir.max_pool2d only
handles cases where `kernel size` and `strides` are greater than 1.

### What's changed
TTIR to TTIR decomposition pass modified to handle following cases
* Kernel size of 1
* Kernel size of (1, n) or (n, 1)
* Stride size of 1
* Stride size of (1, n) or (n, 1)

### Checklist
- [X] New tests provide coverage for changes
  • Loading branch information
mmanzoorTT authored Feb 21, 2025
1 parent 04758dc commit 4e8b4c3
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 105 deletions.
185 changes: 93 additions & 92 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,77 +831,108 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
public:
using OpConversionPattern<ttir::PoolingOp>::OpConversionPattern;

std::vector<int64_t> getIndicesOfSpatialDims(ttir::PoolingOp op) const {
std::vector<int64_t> spatialDims;
for (int64_t i = 0;
i < static_cast<int64_t>(op.getWindowDimensions().size()); i++) {
if (op.getWindowDimensions()[i] > 1) {
spatialDims.push_back(i);
LogicalResult
matchAndRewrite(ttir::PoolingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<int64_t> spatialDimIndices =
getIndicesOfElementsLargerThanOne(op.getWindowDimensions());
size_t numSpatialDimIndices = spatialDimIndices.size();
if (numSpatialDimIndices > 2) {
return rewriter.notifyMatchFailure(
op, "No decompositions for a pooling op with " +
std::to_string(numSpatialDimIndices) + " spatial dimensions");
}

LogicalResult legalityResult =
canDecompose2DPoolingOp(op, rewriter, spatialDimIndices);
if (!legalityResult.succeeded()) {
return legalityResult;
}

switch (op.getPoolingMethod()) {
case ttir::PoolingMethod::Max: {
rewritePool2d<ttir::MaxPool2dOp>(op, adaptor, rewriter,
spatialDimIndices);
return success();
}
default: {
return rewriter.notifyMatchFailure(
op, "Failed to match pooling method: " +
stringifyPoolingMethod(op.getPoolingMethod()));
}
}
}

private:
llvm::SmallVector<int64_t>
getIndicesOfElementsLargerThanOne(llvm::ArrayRef<int64_t> input) const {
llvm::SmallVector<int64_t, 2> result;
for (size_t i = 0; i < input.size(); i++) {
if (input[i] > 1) {
result.push_back(i);
}
}
return spatialDims;
return result;
}

LogicalResult canDecompose2DPoolingOp(ttir::PoolingOp op) const {
LogicalResult
canDecompose2DPoolingOp(ttir::PoolingOp op,
ConversionPatternRewriter &rewriter,
llvm::SmallVector<int64_t> spatialDimIndices) const {

// Window dimensions must be 4 in length
if (op.getWindowDimensions().size() != 4) {
return failure();
return rewriter.notifyMatchFailure(
op, "Pooling 2D op is only supported for 4D tensors.");
}

// Window strides must be 4 in length
if (op.getWindowStrides().size() != 4) {
return failure();
return rewriter.notifyMatchFailure(
op, "Pooling 2D op is only supported for 4D tensors.");
}

// Operand rank(s) must be 4
for (Value operand : op.getInputs()) {
auto operandType = mlir::cast<mlir::RankedTensorType>(operand.getType());
if (operandType.getRank() != 4) {
return failure();
}
}

// Exactly two of the window dimensions must be greater than 1
std::vector<int64_t> trueWindowDimensionsIndices =
getIndicesOfSpatialDims(op);

if (trueWindowDimensionsIndices.size() != 2) {
return failure();
}

// Exactly two of the window strides must be greater than 1
std::vector<int64_t> trueWindowStrideIndices;
for (int64_t i = 0; i < static_cast<int64_t>(op.getWindowStrides().size());
i++) {
if (op.getWindowStrides()[i] > 1) {
trueWindowStrideIndices.push_back(i);
return rewriter.notifyMatchFailure(
op, "Pooling 2D op is only supported for 4D tensors.");
}
}

if (trueWindowStrideIndices.size() != 2) {
return failure();
// Window dimensions may contain at most two elements greater than 1;
// these represent the kernel size for the max pooling operation.
size_t numSpatialDimIndices = spatialDimIndices.size();
if (numSpatialDimIndices > 2) {
return rewriter.notifyMatchFailure(
op, "Rank of kernel_size for pooling 2D op is greater than 2.");
}

// The indices of the true window dimensions and strides must be the same
if ((trueWindowDimensionsIndices[0] != trueWindowStrideIndices[0] ||
trueWindowDimensionsIndices[1] != trueWindowStrideIndices[1]) &&
(trueWindowDimensionsIndices[0] != trueWindowStrideIndices[1] ||
trueWindowDimensionsIndices[1] != trueWindowStrideIndices[0])) {
return failure();
// Window strides may contain at most two elements greater than 1;
// these represent the strides for the max pooling operation.
llvm::SmallVector<int64_t> trueWindowStrideIndices =
getIndicesOfElementsLargerThanOne(op.getWindowStrides());
size_t windowStrideSize = trueWindowStrideIndices.size();
if (windowStrideSize > 2) {
return rewriter.notifyMatchFailure(
op, "Rank of strides for pooling 2D is greater than 2.");
}

// Padding must be 8 in length
if (op.getPadding().size() != 8) {
return failure();
return rewriter.notifyMatchFailure(
op,
"Number of elements in padding does not match with pooling 2D op.");
}

return success();
}

template <typename PoolOpType>
void rewritePool2d(ttir::PoolingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter,
llvm::SmallVector<int64_t> spatialDimIndices) const {

const int64_t SPATIAL_H = -3;
const int64_t SPATIAL_W = -2;
Expand All @@ -922,11 +953,20 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
}
}

std::vector<int64_t> spatialDims = getIndicesOfSpatialDims(op);
int64_t numWinDims = op.getWindowDimensions().size();
// Fall back to default spatial indices for a channel-first tensor when the
// window dimensions attribute does not contain exactly two elements greater
// than 1 for the kernel size.
// [TODO] (mmanzoor) Add an option to distinguish channel-first vs
// channel-last layouts and support channel-last default indices.
// https://github.com/tenstorrent/tt-mlir/issues/2237
spatialDimIndices =
(spatialDimIndices.size() == 2)
? spatialDimIndices
: llvm::SmallVector<int64_t>({numWinDims - 2, numWinDims - 1});

std::vector<int64_t> currentLayout(inputType.getRank(), NON_SPATIAL);
currentLayout[spatialDims[0]] = SPATIAL_H;
currentLayout[spatialDims[1]] = SPATIAL_W;
currentLayout[spatialDimIndices[0]] = SPATIAL_H;
currentLayout[spatialDimIndices[1]] = SPATIAL_W;

nonSpatialCount = 0;
for (int64_t i = 0; i < static_cast<int64_t>(currentLayout.size()); i++) {
Expand All @@ -941,30 +981,30 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {
auto inverseOfPermutation = ttmlir::utils::inversePermutation(permutation);

auto kernelHeightAttr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(op.getWindowDimensions()[spatialDims[0]]));
static_cast<int32_t>(op.getWindowDimensions()[spatialDimIndices[0]]));
auto kernelWidthAttr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(op.getWindowDimensions()[spatialDims[1]]));
static_cast<int32_t>(op.getWindowDimensions()[spatialDimIndices[1]]));

auto strideHeightAttr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(op.getWindowStrides()[spatialDims[0]]));
static_cast<int32_t>(op.getWindowStrides()[spatialDimIndices[0]]));

auto strideWidthAttr = rewriter.getSI32IntegerAttr(
static_cast<int32_t>(op.getWindowStrides()[spatialDims[1]]));
static_cast<int32_t>(op.getWindowStrides()[spatialDimIndices[1]]));

auto dilationHeightAttr = rewriter.getSI32IntegerAttr(
adaptor.getWindowDilations()[spatialDims[0]]);
adaptor.getWindowDilations()[spatialDimIndices[0]]);
auto dilationWidthAttr = rewriter.getSI32IntegerAttr(
adaptor.getWindowDilations()[spatialDims[1]]);
adaptor.getWindowDilations()[spatialDimIndices[1]]);
auto ceilModeAttr = rewriter.getBoolAttr(false);

auto paddingTopAttr =
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[0]]);
auto paddingBottomAttr =
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[0] + 1]);
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDimIndices[0]]);
auto paddingBottomAttr = rewriter.getSI32IntegerAttr(
op.getPadding()[2 * spatialDimIndices[0] + 1]);
auto paddingLeftAttr =
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1]]);
auto paddingRightAttr =
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]);
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDimIndices[1]]);
auto paddingRightAttr = rewriter.getSI32IntegerAttr(
op.getPadding()[2 * spatialDimIndices[1] + 1]);

llvm::SmallVector<Value> outputs;
for (Value input : adaptor.getInputs()) {
Expand Down Expand Up @@ -999,45 +1039,6 @@ struct PoolingToPool2dPattern : public OpConversionPattern<ttir::PoolingOp> {

rewriter.replaceOp(op, outputs);
}

uint32_t getNumSpatialDims(ttir::PoolingOp op) const {
uint32_t numSpatialDims = 0;
for (int64_t dim : op.getWindowDimensions()) {
if (dim > 1) {
numSpatialDims++;
}
}
return numSpatialDims;
}

LogicalResult
matchAndRewrite(ttir::PoolingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

uint32_t numSpatialDims = getNumSpatialDims(op);
if (numSpatialDims == 2) {
if (failed(canDecompose2DPoolingOp(op))) {
return rewriter.notifyMatchFailure(
op, "2D pooling op with the given attributes is not supported "
"currently");
}

switch (op.getPoolingMethod()) {
case ttir::PoolingMethod::Max: {
rewritePool2d<ttir::MaxPool2dOp>(op, adaptor, rewriter);
return success();
}
default: {
return rewriter.notifyMatchFailure(
op, "Failed to match pooling method: " +
stringifyPoolingMethod(op.getPoolingMethod()));
}
}
}
return rewriter.notifyMatchFailure(
op, "No decompositions for a pooling op with " +
std::to_string(numSpatialDims) + " spatial dimensions");
}
};
} // namespace

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct TTIRToTTIRDecompositionPass
// func.func and
// func.call as legal ops
target.addLegalDialect<BuiltinDialect>(); // This contains the "module" op
// which is necesarry
// which is necessary

target.addLegalOp<tensor::EmptyOp>(); // DPS operands are create with
// tensor::EmptyOp
Expand Down
103 changes: 103 additions & 0 deletions test/ttmlir/Decomposition/TTIR/pooling/max_pool2d.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s

// Coverage for TTIR-to-TTIR decomposition of ttir.pooling (Max method) into
// ttir.max_pool2d, including kernel and stride sizes of 1 in one or both
// spatial dimensions. Inputs are NCHW; per the CHECK lines, the decomposition
// permutes to NHWC (0, 2, 3, 1), pools, then permutes back (0, 3, 1, 2).
module attributes {} {
// Kernel size = 1; stride = 1
// Spatial dims unchanged: floor((28 - 1) / 1) + 1 = 28 for both H and W.
func.func @test_maxpool2d_kernel_1x1_stride_1x1(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x28x28xbf16> {
// CHECK-LABEL: func.func @test_maxpool2d_kernel_1x1_stride_1x1(
%0 = tensor.empty() : tensor<1x192x28x28xbf16>
// CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
// CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
// CHECK-SAME: (tensor<1x192x28x28xbf16>, tensor<1x28x28x192xbf16>)
// CHECK-SAME: -> tensor<1x28x28x192xbf16>
// CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
// CHECK-SAME: kernel_height = 1 : si32, kernel_width = 1 : si32,
// CHECK-SAME: padding_bottom = 0 : si32, padding_left = 0 : si32, padding_right = 0 : si32, padding_top = 0 : si32,
// CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32
// CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x28x28x192xbf16>)
// CHECK-SAME: -> tensor<1x28x28x192xbf16>
%1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 1>, window_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x192x28x28xbf16>, tensor<1x192x28x28xbf16>) -> tensor<1x192x28x28xbf16>
// CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
// CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
// CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x192x28x28xbf16>)
// CHECK-SAME: -> tensor<1x192x28x28xbf16>
// CHECK: return %[[RET]] : tensor<1x192x28x28xbf16>
return %1 : tensor<1x192x28x28xbf16>
}

// Kernel size = 3; stride = 1
// Padding of 1 on each spatial edge preserves 28x28: floor((28 + 2 - 3) / 1) + 1 = 28.
func.func @test_maxpool2d_kernel_3x3_stride_1x1(%arg0: tensor<1x256x28x28xbf16>) -> tensor<1x256x28x28xbf16> {
// CHECK-LABEL: func.func @test_maxpool2d_kernel_3x3_stride_1x1(
%0 = tensor.empty() : tensor<1x256x28x28xbf16>
// CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
// CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
// CHECK-SAME: (tensor<1x256x28x28xbf16>, tensor<1x28x28x256xbf16>)
// CHECK-SAME: -> tensor<1x28x28x256xbf16>
// CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
// CHECK-SAME: kernel_height = 3 : si32, kernel_width = 3 : si32,
// CHECK-SAME: padding_bottom = 1 : si32, padding_left = 1 : si32, padding_right = 1 : si32, padding_top = 1 : si32,
// CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32
// CHECK-SAME: (tensor<1x28x28x256xbf16>, tensor<1x28x28x256xbf16>)
// CHECK-SAME: -> tensor<1x28x28x256xbf16>
%1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 1, 1, 1, 1>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 3, 3>, window_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x256x28x28xbf16>, tensor<1x256x28x28xbf16>) -> tensor<1x256x28x28xbf16>
// CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
// CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
// CHECK-SAME: (tensor<1x28x28x256xbf16>, tensor<1x256x28x28xbf16>)
// CHECK-SAME: -> tensor<1x256x28x28xbf16>
// CHECK: return %[[RET]] : tensor<1x256x28x28xbf16>
return %1 : tensor<1x256x28x28xbf16>
}

// Kernel size = (2, 1); stride = 1
// Height shrinks: floor((28 - 2) / 1) + 1 = 27; width stays 28.
func.func @test_maxpool2d_kernel_2x1_stride_1x1(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x27x28xbf16> {
// CHECK-LABEL: func.func @test_maxpool2d_kernel_2x1_stride_1x1(
%0 = tensor.empty() : tensor<1x192x27x28xbf16>
// CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
// CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
// CHECK-SAME: (tensor<1x192x28x28xbf16>, tensor<1x28x28x192xbf16>)
// CHECK-SAME: -> tensor<1x28x28x192xbf16>
// CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
// CHECK-SAME: kernel_height = 2 : si32, kernel_width = 1 : si32,
// CHECK-SAME: padding_bottom = 0 : si32, padding_left = 0 : si32, padding_right = 0 : si32, padding_top = 0 : si32,
// CHECK-SAME: stride_height = 1 : si32, stride_width = 1 : si32
// CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x27x28x192xbf16>)
// CHECK-SAME: -> tensor<1x27x28x192xbf16>
%1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 2, 1>, window_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x192x28x28xbf16>, tensor<1x192x27x28xbf16>) -> tensor<1x192x27x28xbf16>
// CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
// CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
// CHECK-SAME: (tensor<1x27x28x192xbf16>, tensor<1x192x27x28xbf16>)
// CHECK-SAME: -> tensor<1x192x27x28xbf16>
// CHECK: return %[[RET]] : tensor<1x192x27x28xbf16>
return %1 : tensor<1x192x27x28xbf16>
}

// Kernel size = (1, 2); stride = (3, 1)
// Height: floor((28 - 1) / 3) + 1 = 10; width: floor((28 - 2) / 1) + 1 = 27.
func.func @test_maxpool2d_kernel_1x2_stride_3x1(%arg0: tensor<1x192x28x28xbf16>) -> tensor<1x192x10x27xbf16> {
// CHECK-LABEL: func.func @test_maxpool2d_kernel_1x2_stride_3x1(
%0 = tensor.empty() : tensor<1x192x10x27xbf16>
// CHECK: %[[PERMUTE:[0-9]+]] = "ttir.permute"(%arg0
// CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
// CHECK-SAME: (tensor<1x192x28x28xbf16>, tensor<1x28x28x192xbf16>)
// CHECK-SAME: -> tensor<1x28x28x192xbf16>
// CHECK: %[[MAXPOOL:[0-9]+]] = "ttir.max_pool2d"(%[[PERMUTE]],
// CHECK-SAME: ceil_mode = false,
// CHECK-SAME: dilation_height = 1 : si32, dilation_width = 1 : si32,
// CHECK-SAME: kernel_height = 1 : si32, kernel_width = 2 : si32,
// CHECK-SAME: padding_bottom = 0 : si32, padding_left = 0 : si32, padding_right = 0 : si32, padding_top = 0 : si32,
// CHECK-SAME: stride_height = 3 : si32, stride_width = 1 : si32
// CHECK-SAME: (tensor<1x28x28x192xbf16>, tensor<1x10x27x192xbf16>)
// CHECK-SAME: -> tensor<1x10x27x192xbf16>
%1 = "ttir.pooling"(%arg0, %0) <{base_dilations = array<i64: 1, 1, 1, 1>, operandSegmentSizes = array<i32: 1, 1>, padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>, pooling_method = #ttir<pooling_method Max>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 1, 2>, window_strides = array<i64: 1, 1, 3, 1>}> : (tensor<1x192x28x28xbf16>, tensor<1x192x10x27xbf16>) -> tensor<1x192x10x27xbf16>
// CHECK: %[[RET:[0-9]+]] = "ttir.permute"(%[[MAXPOOL]],
// CHECK-SAME: permutation = array<i64: 0, 3, 1, 2>
// CHECK-SAME: (tensor<1x10x27x192xbf16>, tensor<1x192x10x27xbf16>)
// CHECK-SAME: -> tensor<1x192x10x27xbf16>
// CHECK: return %[[RET]] : tensor<1x192x10x27xbf16>
return %1 : tensor<1x192x10x27xbf16>
}
}
Loading

0 comments on commit 4e8b4c3

Please sign in to comment.