Skip to content

Commit

Permalink
Convert tt.reshape to tensor.collapse_shape or tensor.expand_shape when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiaoran Weng committed Dec 2, 2024
1 parent 3fe82cb commit 1657ad2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -1849,10 +1850,25 @@ class ReshapeConverter : public OpConversionPattern<triton::ReshapeOp> {
auto input = op.getSrc();
auto output = op.getResult();

auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType) {
auto inputType = input.getType();
auto outputType = output.getType();
if (!outputType.hasStaticShape()) {
return failure();
}

if (auto maybeReassociationMap =
getReassociationIndicesForReshape(inputType, outputType)) {
auto reassociationMap = *maybeReassociationMap;
if (outputType.getRank() < inputType.getRank()) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, outputType, input, reassociationMap);
} else {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
op, outputType, input, reassociationMap);
}
return success();
}

ArrayRef<int64_t> outputShape = outputType.getShape();

auto shape = rewriter.create<arith::ConstantOp>(
Expand Down
12 changes: 4 additions & 8 deletions test/Conversion/StructuredToMemref/convert_tensor_reshape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,21 @@ module {
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func.func @bcast_kernel_01
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
// CHECK-DAG: [[CST_2048_:%.+]] = arith.constant 2048 : i64
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<[1, 32]> : tensor<2xi64>
// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : i32
// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_5_]], [[CST_32_]] : i32
// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32>
// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
// CHECK-DAG: [[VAR_reshape_:%.+]] = tensor.reshape [[VAR_2_]]([[VAR_cst_]]) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32>
// CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_2_]] {{.}}[0, 1]{{.}} output_shape [1, 32] : tensor<32xf32> into tensor<1x32xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = tensor.empty() : tensor<64x32xf32>
// CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins([[VAR_reshape_]] : tensor<1x32xf32>) outs([[VAR_3_]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<1x32xf32>) outs([[VAR_3_]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
// CHECK: linalg.yield [[IN_0_]] : f32
// CHECK: } -> tensor<64x32xf32>
// CHECK: [[VAR_5_:%.+]] = tensor.empty() : tensor<1xi64>
// CHECK: [[VAR_6_:%.+]] = linalg.fill ins([[CST_2048_]] : i64) outs([[VAR_5_]] : tensor<1xi64>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_reshape_0_:%.+]] = tensor.reshape [[VAR_4_]]([[VAR_6_]]) : (tensor<64x32xf32>, tensor<1xi64>) -> tensor<2048xf32>
// CHECK-DAG: [[VAR_collapsed_:%.+]] = tensor.collapse_shape [[VAR_4_]] {{.}}[0, 1]{{.}} : tensor<64x32xf32> into tensor<2048xf32>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [2048], strides: [1] : memref<*xf32> to memref<2048xf32, strided<[1], offset: ?>>
// CHECK: bufferization.materialize_in_destination [[VAR_reshape_0_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> ()
// CHECK: bufferization.materialize_in_destination [[VAR_collapsed_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
19 changes: 7 additions & 12 deletions test/Conversion/TritonArithToLinalg/convert_tensor_reshape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,31 @@ module {
// CHECK: [[VAR_9_:%.+]] = linalg.fill ins([[VAR_0_]] : i32) outs([[VAR_8_]] : tensor<2048xi32>) -> tensor<2048xi32>
// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]], [[VAR_7_]] : tensor<2048xi32>, tensor<2048xi32>) outs([[VAR_9_]] : tensor<2048xi32>) {
// CHECK: ^bb0([[in_]]: i32, [[in_1_]]: i32, [[out_]]: i32):
// CHECK: [[VAR_22_3_:%.+]] = arith.addi [[in_]], [[in_]]_1 : i32
// CHECK: [[VAR_22_3_:%.+]] = arith.addi [[in_]], [[in_1_]] : i32
// CHECK: linalg.yield [[VAR_22_3_]] : i32
// CHECK: } -> tensor<2048xi32>
// CHECK: [[VAR_11_:%.+]] = tensor.empty() : tensor<32x!tt.ptr<f32>>
// CHECK: [[VAR_12_:%.+]] = linalg.fill ins([[PARAM_0_]] : !tt.ptr<f32>) outs([[VAR_11_]] : tensor<32x!tt.ptr<f32>>) -> tensor<32x!tt.ptr<f32>>
// CHECK: [[VAR_13_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_12_]], [[VAR_5_]] : tensor<32x!tt.ptr<f32>>, tensor<32xi32>) outs([[VAR_12_]] : tensor<32x!tt.ptr<f32>>) {
// CHECK: ^bb0([[in_]]: !tt.ptr<f32>, [[in_1_]]: i32, [[out_]]: !tt.ptr<f32>):
// CHECK: [[VAR_22_4_:%.+]] = tt.addptr [[in_]], [[in_]]_1 : !tt.ptr<f32>, i32
// CHECK: [[VAR_22_4_:%.+]] = tt.addptr [[in_]], [[in_1_]] : !tt.ptr<f32>, i32
// CHECK: linalg.yield [[VAR_22_4_]] : !tt.ptr<f32>
// CHECK: } -> tensor<32x!tt.ptr<f32>>
// CHECK-DAG: [[LOAD_VAR_13_MEM_:%.+]] = tt.load [[VAR_13_]] : tensor<32x!tt.ptr<f32>>
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<[1, 32]> : tensor<2xi64>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reshape_:%.+]] = tensor.reshape [[LOAD_VAR_13_MEM_]]([[VAR_cst_]]) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32>
// CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[LOAD_VAR_13_MEM_]] {{.}}[0, 1]{{.}} output_shape [1, 32] : tensor<32xf32> into tensor<1x32xf32>
// CHECK-DAG: [[VAR_15_:%.+]] = tensor.empty() : tensor<64x32xf32>
// CHECK: [[VAR_16_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_reshape_]] : tensor<1x32xf32>) outs([[VAR_15_]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: [[VAR_16_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<1x32xf32>) outs([[VAR_15_]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: ^bb0([[in_]]: f32, [[out_]]: f32):
// CHECK: linalg.yield [[in_]] : f32
// CHECK: } -> tensor<64x32xf32>
// CHECK-DAG: [[CST_2048_:%.+]] = arith.constant 2048 : i64
// CHECK-DAG: [[VAR_17_:%.+]] = tensor.empty() : tensor<1xi64>
// CHECK: [[VAR_18_:%.+]] = linalg.fill ins([[CST_2048_]] : i64) outs([[VAR_17_]] : tensor<1xi64>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_reshape_0_:%.+]] = tensor.reshape [[VAR_16_]]([[VAR_18_]]) : (tensor<64x32xf32>, tensor<1xi64>) -> tensor<2048xf32>
// CHECK-DAG: [[VAR_collapsed_:%.+]] = tensor.collapse_shape [[VAR_16_]] {{.}}[0, 1]{{.}} : tensor<64x32xf32> into tensor<2048xf32>
// CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<2048x!tt.ptr<f32>>
// CHECK: [[VAR_20_:%.+]] = linalg.fill ins([[PARAM_1_]] : !tt.ptr<f32>) outs([[VAR_19_]] : tensor<2048x!tt.ptr<f32>>) -> tensor<2048x!tt.ptr<f32>>
// CHECK: [[VAR_21_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_20_]], [[VAR_10_]] : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>) outs([[VAR_20_]] : tensor<2048x!tt.ptr<f32>>) {
// CHECK: ^bb0([[in_]]: !tt.ptr<f32>, [[in_1_]]: i32, [[out_]]: !tt.ptr<f32>):
// CHECK: [[VAR_22_5_:%.+]] = tt.addptr [[in_]], [[in_]]_1 : !tt.ptr<f32>, i32
// CHECK: [[VAR_22_5_:%.+]] = tt.addptr [[in_]], [[in_1_]] : !tt.ptr<f32>, i32
// CHECK: linalg.yield [[VAR_22_5_]] : !tt.ptr<f32>
// CHECK: } -> tensor<2048x!tt.ptr<f32>>
// CHECK: tt.store [[VAR_21_]], [[VAR_reshape_0_]] : tensor<2048x!tt.ptr<f32>>
// CHECK: tt.store [[VAR_21_]], [[VAR_collapsed_]] : tensor<2048x!tt.ptr<f32>>
// CHECK: return
// CHECK: }
12 changes: 4 additions & 8 deletions test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,21 @@ module {


// CHECK-LABEL: func.func @bcast_kernel_01(
// CHECK: %[[C2048_I64:.*]] = arith.constant 2048 : i64
// CHECK: %[[CST:.*]] = arith.constant dense<[1, 32]> : tensor<2xi64>
// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
// CHECK: %[[VAR_0:.*]] = arith.muli %arg5, %[[C32_I32]] : i32
// CHECK: %[[VAR_1:.*]] = arith.index_cast %[[VAR_0]] : i32 to index
// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[VAR_1]]], sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32>
// CHECK: memref.copy %[[REINTERPRET_CAST:.*]], %[[ALLOC]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32>
// CHECK: %[[VAR_2:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<32xf32>
// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[VAR_2]](%[[CST]]) : (tensor<32xf32>, tensor<2xi64>) -> tensor<1x32xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAR_2]] {{.}}[0, 1]{{.}} output_shape [1, 32] : tensor<32xf32> into tensor<1x32xf32>
// CHECK: %[[VAR_3:.*]] = tensor.empty() : tensor<64x32xf32>
// CHECK: %[[VAR_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[RESHAPE]] : tensor<1x32xf32>) outs(%[[VAR_3:.*]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: %[[VAR_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[EXPANDED]] : tensor<1x32xf32>) outs(%[[VAR_3:.*]] : tensor<64x32xf32>) attrs = {broadcastDims = array<i64: 0>} {
// CHECK: ^bb0(%in: f32, %out: f32):
// CHECK: linalg.yield %in : f32
// CHECK: } -> tensor<64x32xf32>
// CHECK: %[[VAR_5:.*]] = tensor.empty() : tensor<1xi64>
// CHECK: %[[VAR_6:.*]] = linalg.fill ins(%[[C2048_I64]] : i64) outs(%[[VAR_5]] : tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[RESHAPE_0:.*]] = tensor.reshape %[[VAR_4]](%[[VAR_6]]) : (tensor<64x32xf32>, tensor<1xi64>) -> tensor<2048xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAR_4]] {{.}}[0, 1]{{.}} : tensor<64x32xf32> into tensor<2048xf32>
// CHECK: %[[VAR_7:.*]] = arith.index_cast %[[VAR_0]] : i32 to index
// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %arg1 to offset: [%[[VAR_7]]], sizes: [2048], strides: [1] : memref<*xf32> to memref<2048xf32, strided<[1], offset: ?>>
// CHECK: bufferization.materialize_in_destination %[[RESHAPE_0]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> ()
// CHECK: bufferization.materialize_in_destination %[[COLLAPSED]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return

0 comments on commit 1657ad2

Please sign in to comment.