Adjust the operand order in the tt.dot to linalg.matmul converter (#191)
A tt.dot with an accumulator lowers to linalg.matmul followed by arith.add,
and the arith.add is further lowered to linalg.generic. That generic takes
the lhs of the add as its DPS (destination-passing style) init, so the lhs
of the add must be the accumulator.
MercuryChen authored Dec 30, 2024
1 parent d5b7bee commit d9933bb
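
For context, a minimal MLIR sketch of the lowering this change targets
(hypothetical value names and shapes, modeled on the tests below). The
accumulator must be the lhs of the add, because the elementwise lowering
reuses that operand as the DPS init of the resulting linalg.generic:

// Input (sketch): tt.dot with accumulator %acc, e.g.
//   %res = tt.dot %a, %b, %acc : ... -> tensor<128x256xbf16>
// After conversion: matmul into a zero-filled init, then the accumulator add.
#map = affine_map<(d0, d1) -> (d0, d1)>
%zero = arith.constant 0.0 : bf16
%empty = tensor.empty() : tensor<128x256xbf16>
%init = linalg.fill ins(%zero : bf16) outs(%empty : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%mm = linalg.matmul ins(%a, %b : tensor<128x64xbf16>, tensor<64x256xbf16>)
                    outs(%init : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// The accumulator is both the first input and the DPS init (outs).
%res = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]}
    ins(%acc, %mm : tensor<128x256xbf16>, tensor<128x256xbf16>)
    outs(%acc : tensor<128x256xbf16>) {
^bb0(%in0: bf16, %in1: bf16, %out: bf16):
  %sum = arith.addf %in0, %in1 : bf16
  linalg.yield %sum : bf16
} -> tensor<128x256xbf16>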
Showing 4 changed files with 5 additions and 5 deletions.
@@ -1162,9 +1162,9 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {

if (!skipC) {
  if (integers) {
-    res = rewriter.create<arith::AddIOp>(loc, res, opc);
+    res = rewriter.create<arith::AddIOp>(loc, opc, res);
  } else {
-    res = rewriter.create<arith::AddFOp>(loc, res, opc);
+    res = rewriter.create<arith::AddFOp>(loc, opc, res);
  }
}

test/Conversion/StructuredToMemref/dot.mlir (2 changes: 1 addition & 1 deletion)
@@ -74,7 +74,7 @@ module {
// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16>
// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_6_]], [[VAR_3_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_6_]] : tensor<128x256xbf16>) {
+// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[IN_0_:%.+]]: bf16, [[IN_1_:%.+]]: bf16, [[IN_2_:%.+]]: bf16):
// CHECK: [[VAR_8_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : bf16
// CHECK: linalg.yield [[VAR_8_]] : bf16
test/Conversion/TritonArithToLinalg/dot.mlir (2 changes: 1 addition & 1 deletion)
@@ -191,7 +191,7 @@ module {
// CHECK-DAG: [[VAR_45_:%.+]] = tensor.empty() : tensor<128x256xbf16>
// CHECK: [[VAR_46_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_45_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_47_:%.+]] = linalg.matmul ins([[LOAD_VAR_34_MEM_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_46_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
-// CHECK: [[VAR_48_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_47_]], [[LOAD_VAR_43_MEM_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_47_]] : tensor<128x256xbf16>) {
+// CHECK: [[VAR_48_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[LOAD_VAR_43_MEM_]], [[VAR_47_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[LOAD_VAR_43_MEM_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[in_]]: bf16, [[in_1:.+]]: bf16, [[out_]]: bf16):
// CHECK: [[VAR_49_13_:%.+]] = arith.addf [[in_]], [[in_1]] : bf16
// CHECK: linalg.yield [[VAR_49_13_]] : bf16
test/Conversion/TritonToLinalg/dot.mlir (2 changes: 1 addition & 1 deletion)
@@ -74,7 +74,7 @@ module {
// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16>
// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_6_]], [[VAR_3_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_6_]] : tensor<128x256xbf16>) {
+// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) {
// CHECK: ^bb0([[in_:.+]]: bf16, [[in_1:.+]]: bf16, [[out_:.+]]: bf16):
// CHECK: [[VAR_8_:%.+]] = arith.addf [[in_]], [[in_1]] : bf16
// CHECK: linalg.yield [[VAR_8_]] : bf16
