diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 14ae50ba..dea2c4c1 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -993,6 +993,97 @@ struct PreciseDivConverter : public OpConversionPattern { } }; +struct CatConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto replacement = rewriter.create( + op.getLoc(), 0 /* concat dimension */, adaptor.getOperands()); + + rewriter.replaceOp(op, replacement); + + return success(); + } +}; + +struct SplitConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getOperand(); + auto inputType = cast(input.getType()); + + Type resultType = op.getResults().front().getType(); + auto resultTensor = cast(resultType); + auto shape = inputType.getShape(); + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector sizes = + llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + SmallVector results; + + for (int i = 0; i < 2; ++i) { + offsets.pop_back(); + sizes.pop_back(); + + offsets.push_back(rewriter.getIndexAttr(i)); + sizes.push_back(rewriter.getIndexAttr(1)); + Value slice = rewriter.create( + loc, resultTensor, input, offsets, sizes, strides); + results.push_back(slice); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +struct JoinConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange inputs = op.getOperands(); + + auto resultType = cast(op.getResult().getType()); + + auto loc = op.getLoc(); + Value result = rewriter.create(loc, resultType.getShape(), resultType.getElementType()); + + auto shape = resultType.getShape(); + + SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); + SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); + SmallVector sizes = + llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + + for (int i = 0; i < 2; ++i) { + offsets.pop_back(); + sizes.pop_back(); + + offsets.push_back(rewriter.getIndexAttr(i)); + sizes.push_back(rewriter.getIndexAttr(1)); + result = rewriter.create(loc, inputs[i], result, offsets, sizes, strides); + } + + rewriter.replaceOp(op, result); + + return success(); + } +}; + struct MulHiUIOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp index 25cf9e6f..15fc4003 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -224,7 +225,6 @@ class StructuredToMemrefPass LogicalResult convertArgsToMemrefType() { auto moduleOp = getOperation(); - RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); TritonFunctionSignatureConverter typeConverter; diff --git a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp index 1a24a38b..0a5a4628 100644 --- a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp +++ b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp @@ -63,6 +63,9 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); diff --git a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp index 0148620e..381d5690 100644 --- a/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp +++ b/lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -73,6 +74,19 @@ class TritonArithToLinalgPass } } + LogicalResult applyTensorConcatDecomposition() { + auto moduleOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + tensor::populateDecomposeTensorConcatPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) { + return failure(); + } + return success(); + } + public: void getDependentDialects(DialectRegistry ®istry) const override { registry @@ -161,6 +175,10 @@ class TritonArithToLinalgPass signalPassFailure(); } + if (failed(applyTensorConcatDecomposition())) { + signalPassFailure(); + } + // Convert tt.func and tt.return into func's counterparts if (ttToFuncFunc) { moduleOp.walk([&](triton::FuncOp func) { diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp index 35630328..d85cab7a 100644 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp @@ -55,6 +55,9 @@ void mlir::triton::populateTritonToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); diff --git a/python/examples/conftest.py b/python/examples/conftest.py index 441f0033..634a04cb 100644 --- a/python/examples/conftest.py +++ b/python/examples/conftest.py @@ -20,7 +20,6 @@ def device(request): tests_not_supported = { "test_bin_op", - "test_split", "test_split_to_scalar", "test_interleave_scalars", "test_pointer_arguments", @@ -34,7 +33,6 @@ def device(request): "test_ptx_cast", "test_compare_op", "test_maxnreg", - "test_join", "test_join_scalars", "test_join_with_mma", "test_interleave", @@ -51,7 +49,6 @@ def device(request): "test_atomic_cas", "test_tensor_atomic_cas", "test_cast", - "test_cat", "test_store_constant", "test_reduce", "test_reduce1d", diff --git a/test/Conversion/TritonArithToLinalg/block_ptr_advance.mlir b/test/Conversion/TritonArithToLinalg/block_ptr_advance.mlir index 0ce15fc6..900762df 100644 --- a/test/Conversion/TritonArithToLinalg/block_ptr_advance.mlir +++ b/test/Conversion/TritonArithToLinalg/block_ptr_advance.mlir @@ -36,12 +36,12 @@ module { // CHECK-LABEL: func.func @matmul_kernel_with_block_pointers_01234567891011 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32, [[PARAM_15_:%.+]]: i32, [[PARAM_16_:%.+]]: i32, [[PARAM_17_:%.+]]: i32, [[PARAM_18_:%.+]]: i32, [[PARAM_19_:%.+]]: i32) { // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x64xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<128x64xbf16>) -> tensor<128x64xbf16> // CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x64xbf16> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<128x64xbf16>) -> tensor<128x64xbf16> // CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_3_]] : i32 to i64 // CHECK-DAG: [[VAR_3_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64 // CHECK-DAG: [[VAR_4_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 diff --git a/test/Conversion/TritonArithToLinalg/convert_addi_reduce.mlir b/test/Conversion/TritonArithToLinalg/convert_addi_reduce.mlir index e566528d..28da0804 100644 --- a/test/Conversion/TritonArithToLinalg/convert_addi_reduce.mlir +++ b/test/Conversion/TritonArithToLinalg/convert_addi_reduce.mlir @@ -19,9 +19,8 @@ module { // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_1_]] into [[VAR_2_]][] : tensor +// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] // CHECK: ([[in_:%.+]]: i32, [[in_]]it: i32) { // CHECK: [[VAR_3_:%.+]] = arith.addi [[in_]], [[in_]]it : i32 diff --git a/test/Conversion/TritonArithToLinalg/convert_argmin_argmax.mlir b/test/Conversion/TritonArithToLinalg/convert_argmin_argmax.mlir index 96986b51..cdcbd4e4 100644 --- a/test/Conversion/TritonArithToLinalg/convert_argmin_argmax.mlir +++ b/test/Conversion/TritonArithToLinalg/convert_argmin_argmax.mlir @@ -60,6 +60,8 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @argmax_012 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 // CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { @@ -71,20 +73,18 @@ module { // CHECK: [[VAR_3_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK: [[VAR_4_:%.+]] = linalg.fill ins([[VAR_0_]] : i32) outs([[VAR_3_]] : tensor<4096xi32>) -> tensor<4096xi32> // CHECK: [[VAR_5_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xi32>, tensor<4096xi32>) outs([[VAR_4_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[in_:%.+]]: i32, [[in_]]_1: i32, [[out_]]: i32): -// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_]]_1 : i32 +// CHECK: ^bb0([[in_:%.+]]: i32, [[in_1:%.+]]: i32, [[out_]]: i32): +// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_1]] : i32 // CHECK: linalg.yield [[VAR_15_1_]] : i32 // CHECK: } -> tensor<4096xi32> // CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096x!tt.ptr> // CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[PARAM_0_]] : !tt.ptr) outs([[VAR_6_]] : tensor<4096x!tt.ptr>) -> tensor<4096x!tt.ptr> // CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<4096x!tt.ptr>, tensor<4096xi32>) outs([[VAR_7_]] : tensor<4096x!tt.ptr>) { -// CHECK: ^bb0([[in_]]: !tt.ptr, [[in_]]_1: i32, [[out_]]: !tt.ptr): -// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_]]_1 : !tt.ptr, i32 +// CHECK: ^bb0([[in_]]: !tt.ptr, [[in_1:%.+]]: i32, [[out_]]: !tt.ptr): +// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr, i32 // CHECK: linalg.yield [[VAR_15_2_]] : !tt.ptr // CHECK: } -> tensor<4096x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_8_MEM_:%.+]] = tt.load [[VAR_8_]] : tensor<4096x!tt.ptr> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK-DAG: [[VAR_10_:%.+]] = tensor.empty() : tensor // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_11_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_10_]] : tensor) -> tensor @@ -102,7 +102,6 @@ module { // CHECK-DAG: [[VAR_21_:%.+]] = arith.select [[VAR_19_]], [[in_1_]], [[init_2_]] : i32 // CHECK: linalg.yield [[VAR_20_]], [[VAR_21_]] : f32, i32 // CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#0[] : tensor // CHECK-DAG: [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor // CHECK-DAG: [[VAR_14_:%.+]] = tt.addptr [[PARAM_1_]], [[PARAM_6_]] : !tt.ptr, i32 // CHECK: tt.store [[VAR_14_]], [[VAR_extracted_0_]] : !tt.ptr @@ -112,6 +111,8 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @argmin_012 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 // CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { @@ -123,20 +124,18 @@ module { // CHECK: [[VAR_3_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK: [[VAR_4_:%.+]] = linalg.fill ins([[VAR_0_]] : i32) outs([[VAR_3_]] : tensor<4096xi32>) -> tensor<4096xi32> // CHECK: [[VAR_5_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xi32>, tensor<4096xi32>) outs([[VAR_4_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[in_:%.+]]: i32, [[in_]]_1: i32, [[out_]]: i32): -// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_]]_1 : i32 +// CHECK: ^bb0([[in_:%.+]]: i32, [[in_1:%.+]]: i32, [[out_]]: i32): +// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_1]] : i32 // CHECK: linalg.yield [[VAR_15_1_]] : i32 // CHECK: } -> tensor<4096xi32> // CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096x!tt.ptr> // CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[PARAM_0_]] : !tt.ptr) outs([[VAR_6_]] : tensor<4096x!tt.ptr>) -> tensor<4096x!tt.ptr> // CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<4096x!tt.ptr>, tensor<4096xi32>) outs([[VAR_7_]] : tensor<4096x!tt.ptr>) { -// CHECK: ^bb0([[in_]]: !tt.ptr, [[in_]]_1: i32, [[out_]]: !tt.ptr): -// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_]]_1 : !tt.ptr, i32 +// CHECK: ^bb0([[in_]]: !tt.ptr, [[in_1:%.+]]: i32, [[out_]]: !tt.ptr): +// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr, i32 // CHECK: linalg.yield [[VAR_15_2_]] : !tt.ptr // CHECK: } -> tensor<4096x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_8_MEM_:%.+]] = tt.load [[VAR_8_]] : tensor<4096x!tt.ptr> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK-DAG: [[VAR_10_:%.+]] = tensor.empty() : tensor // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_11_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_10_]] : tensor) -> tensor @@ -154,7 +153,6 @@ module { // CHECK-DAG: [[VAR_21_:%.+]] = arith.select [[VAR_19_]], [[in_1_]], [[init_2_]] : i32 // CHECK: linalg.yield [[VAR_20_]], [[VAR_21_]] : f32, i32 // CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#0[] : tensor // CHECK-DAG: [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor // CHECK-DAG: [[VAR_14_:%.+]] = tt.addptr [[PARAM_1_]], [[PARAM_6_]] : !tt.ptr, i32 // CHECK: tt.store [[VAR_14_]], [[VAR_extracted_0_]] : !tt.ptr diff --git a/test/Conversion/TritonArithToLinalg/convert_argmin_argmax_2d.mlir b/test/Conversion/TritonArithToLinalg/convert_argmin_argmax_2d.mlir index 61341aec..a04ceb6c 100644 --- a/test/Conversion/TritonArithToLinalg/convert_argmin_argmax_2d.mlir +++ b/test/Conversion/TritonArithToLinalg/convert_argmin_argmax_2d.mlir @@ -113,6 +113,8 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @test_argmax // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> // CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { // CHECK: ^bb0([[out_:%.+]]: i32): @@ -164,9 +166,7 @@ module { // CHECK: ^bb0([[in_]]: i32, [[out_]]: i32): // CHECK: linalg.yield [[in_]] : i32 // CHECK: } -> tensor<4x4xi32> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32> +// CHECK: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_20_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_19_]] : tensor<4xf32>) -> tensor<4xf32> // CHECK-DAG: [[VAR_21_:%.+]] = tensor.empty() : tensor<4xi32> @@ -205,6 +205,8 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @test_argmin // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> // CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { // CHECK: ^bb0([[out_]]: i32): @@ -256,9 +258,7 @@ module { // CHECK: ^bb0([[in_]]: i32, [[out_]]: i32): // CHECK: linalg.yield [[in_]] : i32 // CHECK: } -> tensor<4x4xi32> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32> +// CHECK: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_20_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_19_]] : tensor<4xf32>) -> tensor<4xf32> // CHECK-DAG: [[VAR_21_:%.+]] = tensor.empty() : tensor<4xi32> diff --git a/test/Conversion/TritonArithToLinalg/convert_minmax_fp_reduce.mlir b/test/Conversion/TritonArithToLinalg/convert_minmax_fp_reduce.mlir index 53de3449..6a72dd74 100644 --- a/test/Conversion/TritonArithToLinalg/convert_minmax_fp_reduce.mlir +++ b/test/Conversion/TritonArithToLinalg/convert_minmax_fp_reduce.mlir @@ -33,10 +33,10 @@ module { // CHECK-LABEL: func.func @maxnumf // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] @@ -51,10 +51,10 @@ module { // CHECK-LABEL: func.func @minnumf // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/convert_minmax_reduce.mlir b/test/Conversion/TritonArithToLinalg/convert_minmax_reduce.mlir index 4f43f5fe..5712ef9e 100644 --- a/test/Conversion/TritonArithToLinalg/convert_minmax_reduce.mlir +++ b/test/Conversion/TritonArithToLinalg/convert_minmax_reduce.mlir @@ -66,11 +66,11 @@ module { // CHECK-LABEL: func.func @minmax_sgt // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { +// CHECK-DAG: [[CST_minus_2147483648_:%.+]] = arith.constant -2147483648 : i32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK-DAG: [[CST_minus_2147483648_:%.+]] = arith.constant -2147483648 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_minus_2147483648_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] @@ -88,9 +88,8 @@ module { // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_1_]] into [[VAR_2_]][] : tensor +// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] // CHECK: ([[in_:%.+]]: i32, [[in_]]it: i32) { // CHECK: [[VAR_3_:%.+]] = arith.maxui [[in_]], [[in_]]it : i32 @@ -103,10 +102,10 @@ module { // CHECK-LABEL: func.func @minmax_slt // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2147483647_:%.+]] = arith.constant 2147483647 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK-DAG: [[CST_2147483647_:%.+]] = arith.constant 2147483647 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_2147483647_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] @@ -121,10 +120,10 @@ module { // CHECK-LABEL: func.func @minmax_ult // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_minus_1_]] into [[VAR_2_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/convert_tensor_reshape.mlir b/test/Conversion/TritonArithToLinalg/convert_tensor_reshape.mlir index aa67356a..a9f61431 100644 --- a/test/Conversion/TritonArithToLinalg/convert_tensor_reshape.mlir +++ b/test/Conversion/TritonArithToLinalg/convert_tensor_reshape.mlir @@ -28,7 +28,7 @@ module { // CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @bcast_kernel_01 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[CST_32_:%.+]] = arith.constant 32 : i32 +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_5_]], [[CST_32_]] : i32 // CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<32xi32> // CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<32xi32>) { diff --git a/test/Conversion/TritonArithToLinalg/dot.mlir b/test/Conversion/TritonArithToLinalg/dot.mlir index e179c143..e161a12b 100644 --- a/test/Conversion/TritonArithToLinalg/dot.mlir +++ b/test/Conversion/TritonArithToLinalg/dot.mlir @@ -55,15 +55,15 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_256_]] : i32) outs([[VAR_0_]] : tensor<128xi32>) -> tensor<128xi32> -// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<64xi32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = linalg.fill ins([[CST_256_1_]] : i32) outs([[VAR_2_]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : i32 +// CHECK-DAG: [[VAR_3_:%.+]] = linalg.fill ins([[CST_256_]] : i32) outs([[VAR_2_]] : tensor<64xi32>) -> tensor<64xi32> // CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_5_:%.+]] = linalg.fill ins([[CST_128_]] : i32) outs([[VAR_4_]] : tensor<128xi32>) -> tensor<128xi32> @@ -165,7 +165,7 @@ module { // CHECK: [[VAR_33_:%.+]] = linalg.fill ins([[PARAM_0_]] : !tt.ptr) outs([[VAR_32_]] : tensor<128x64x!tt.ptr>) -> tensor<128x64x!tt.ptr> // CHECK: [[VAR_34_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_33_]], [[VAR_15_]] : tensor<128x64x!tt.ptr>, tensor<128x64xi32>) outs([[VAR_33_]] : tensor<128x64x!tt.ptr>) { // CHECK: ^bb0([[in_]]: !tt.ptr, [[in_1:.+]]: i32, [[out_]]: !tt.ptr): -// CHECK: [[VAR_49_10_:%.+]] = tt.addptr [[in_]], [[in_]]_6 : !tt.ptr, i32 +// CHECK: [[VAR_49_10_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr, i32 // CHECK: linalg.yield [[VAR_49_10_]] : !tt.ptr // CHECK: } -> tensor<128x64x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_34_MEM_:%.+]] = tt.load [[VAR_34_]] : tensor<128x64x!tt.ptr> @@ -173,7 +173,7 @@ module { // CHECK: [[VAR_37_:%.+]] = linalg.fill ins([[PARAM_1_]] : !tt.ptr) outs([[VAR_36_]] : tensor<256x64x!tt.ptr>) -> tensor<256x64x!tt.ptr> // CHECK: [[VAR_38_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_37_]], [[VAR_25_]] : tensor<256x64x!tt.ptr>, tensor<256x64xi32>) outs([[VAR_37_]] : tensor<256x64x!tt.ptr>) { // CHECK: ^bb0([[in_]]: !tt.ptr, [[in_1:.+]]: i32, [[out_]]: !tt.ptr): -// CHECK: [[VAR_49_11_:%.+]] = tt.addptr [[in_]], [[in_]]_6 : !tt.ptr, i32 +// CHECK: [[VAR_49_11_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr, i32 // CHECK: linalg.yield [[VAR_49_11_]] : !tt.ptr // CHECK: } -> tensor<256x64x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_38_MEM_:%.+]] = tt.load [[VAR_38_]] : tensor<256x64x!tt.ptr> @@ -184,17 +184,16 @@ module { // CHECK: [[VAR_42_:%.+]] = linalg.fill ins([[PARAM_2_]] : !tt.ptr) outs([[VAR_41_]] : tensor<128x256x!tt.ptr>) -> tensor<128x256x!tt.ptr> // CHECK: [[VAR_43_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_42_]], [[VAR_31_]] : tensor<128x256x!tt.ptr>, tensor<128x256xi32>) outs([[VAR_42_]] : tensor<128x256x!tt.ptr>) { // CHECK: ^bb0([[in_]]: !tt.ptr, [[in_1:.+]]: i32, [[out_]]: !tt.ptr): -// CHECK: [[VAR_49_12_:%.+]] = tt.addptr [[in_]], [[in_]]_6 : !tt.ptr, i32 +// CHECK: [[VAR_49_12_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr, i32 // CHECK: linalg.yield [[VAR_49_12_]] : !tt.ptr // CHECK: } -> tensor<128x256x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_43_MEM_:%.+]] = tt.load [[VAR_43_]] : tensor<128x256x!tt.ptr> // CHECK-DAG: [[VAR_45_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK: [[VAR_46_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_45_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> // CHECK: [[VAR_47_:%.+]] = linalg.matmul ins([[LOAD_VAR_34_MEM_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_46_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> // CHECK: [[VAR_48_:%.+]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_47_]], [[LOAD_VAR_43_MEM_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_47_]] : tensor<128x256xbf16>) { -// CHECK: ^bb0([[in_]]: bf16, [[in_]]_6: bf16, [[out_]]: bf16): -// CHECK: [[VAR_49_13_:%.+]] = arith.addf [[in_]], [[in_]]_6 : bf16 +// CHECK: ^bb0([[in_]]: bf16, [[in_1:.+]]: bf16, [[out_]]: bf16): +// CHECK: [[VAR_49_13_:%.+]] = arith.addf [[in_]], [[in_1]] : bf16 // CHECK: linalg.yield [[VAR_49_13_]] : bf16 // CHECK: } -> tensor<128x256xbf16> // CHECK: tt.store [[VAR_43_]], [[VAR_48_]] : tensor<128x256x!tt.ptr> diff --git a/test/Conversion/TritonArithToLinalg/join.mlir b/test/Conversion/TritonArithToLinalg/join.mlir new file mode 100644 index 00000000..7eba54dd --- /dev/null +++ b/test/Conversion/TritonArithToLinalg/join.mlir @@ -0,0 +1,95 @@ +// RUN: triton-shared-opt --triton-arith-to-linalg %s | FileCheck %s + +module { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<2> : tensor<128x1xi32> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.load %5 : tensor<128x!tt.ptr> + %7 = tt.join %3, %6 : tensor<128xi32> -> tensor<128x2xi32> + %8 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %9 = arith.muli %8, %cst : tensor<128x1xi32> + %10 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %12 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %14 = tt.broadcast %11 : tensor<128x1x!tt.ptr> -> tensor<128x2x!tt.ptr> + %15 = tt.broadcast %13 : tensor<1x2xi32> -> tensor<128x2xi32> + %16 = tt.addptr %14, %15 : tensor<128x2x!tt.ptr>, tensor<128x2xi32> + tt.store %16, %7 : tensor<128x2x!tt.ptr> + tt.return + } +} + +// CHECK: func.func @kernel(%arg0: !tt.ptr {{.*}}, %arg1: !tt.ptr {{.*}}, %arg2: !tt.ptr {{.*}}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { +// CHECK: [[C2_I32:%.+]] = arith.constant 2 : i32 +// CHECK: [[EMPTY_128X1_I32:%.+]] = tensor.empty() : tensor<128x1xi32> +// CHECK: [[FILLED_C2:%.+]] = linalg.fill ins([[C2_I32]] : i32) outs([[EMPTY_128X1_I32]] : tensor<128x1xi32>) -> tensor<128x1xi32> +// CHECK: [[EMPTY_128_I32:%.+]] = tensor.empty() : tensor<128xi32> +// CHECK: [[RANGE_128:%.+]] = linalg.generic {{.*}} outs([[EMPTY_128_I32]] : tensor<128xi32>) { +// CHECK: ^bb0(%out: i32): +// CHECK: [[IDX0:%.+]] = linalg.index 0 : index +// CHECK: [[I32_IDX0:%.+]] = arith.index_cast [[IDX0]] : index to i32 +// CHECK: linalg.yield [[I32_IDX0]] : i32 +// CHECK: } -> tensor<128xi32> +// CHECK: [[EMPTY_PTR128:%.+]] = tensor.empty() : tensor<128x!tt.ptr> +// CHECK: [[SPLAT_ARG0:%.+]] = linalg.fill ins(%arg0 : !tt.ptr) outs([[EMPTY_PTR128]] : tensor<128x!tt.ptr>) -> tensor<128x!tt.ptr> +// CHECK: [[ADDPTR_ARG0:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG0]], [[RANGE_128]] : tensor<128x!tt.ptr>, tensor<128xi32>) outs([[SPLAT_ARG0]] : tensor<128x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[NEW_PTR0:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[NEW_PTR0]] : !tt.ptr +// CHECK: } -> tensor<128x!tt.ptr> +// CHECK: [[LOADED_ARG0:%.+]] = tt.load [[ADDPTR_ARG0]] : tensor<128x!tt.ptr> +// CHECK: [[EMPTY_PTR128_1:%.+]] = tensor.empty() : tensor<128x!tt.ptr> +// CHECK: [[SPLAT_ARG1:%.+]] = linalg.fill ins(%arg1 : !tt.ptr) outs([[EMPTY_PTR128_1]] : tensor<128x!tt.ptr>) -> tensor<128x!tt.ptr> +// CHECK: [[ADDPTR_ARG1:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG1]], [[RANGE_128]] : tensor<128x!tt.ptr>, tensor<128xi32>) outs([[SPLAT_ARG1]] : tensor<128x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[NEW_PTR1:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[NEW_PTR1]] : !tt.ptr +// CHECK: } -> tensor<128x!tt.ptr> +// CHECK: [[LOADED_ARG1:%.+]] = tt.load [[ADDPTR_ARG1]] : tensor<128x!tt.ptr> +// CHECK: [[EMPTY_JOIN:%.+]] = tensor.empty() : tensor<128x2xi32> +// CHECK: [[INSERTED_SLICE0:%.+]] = tensor.insert_slice [[LOADED_ARG0]] into [[EMPTY_JOIN]]{{\[}}0, 0{{\]}} [128, 1] [1, 1] : tensor<128xi32> into tensor<128x2xi32> +// CHECK: [[INSERTED_SLICE1:%.+]] = tensor.insert_slice [[LOADED_ARG1]] into [[INSERTED_SLICE0]]{{\[}}0, 1{{\]}} [128, 1] [1, 1] : tensor<128xi32> into tensor<128x2xi32> +// CHECK: [[EXPANDED_RANGE:%.+]] = tensor.expand_shape +// CHECK: [[MULI_RESULT:%.+]] = linalg.generic {{.*}} ins([[EXPANDED_RANGE]], [[FILLED_C2]] : tensor<128x1xi32>, tensor<128x1xi32>) outs([[EXPANDED_RANGE:%.+]] : tensor<128x1xi32>) { +// CHECK: ^bb0([[IN_I32_0:%.+]]: i32, [[IN_I32_1:%.+]]: i32, %out: i32): +// CHECK: [[MUL_RESULT:%.+]] = arith.muli [[IN_I32_0]], [[IN_I32_1]] : i32 +// CHECK: linalg.yield [[MUL_RESULT]] : i32 +// CHECK: } -> tensor<128x1xi32> +// CHECK: [[EMPTY_PTR128X1:%.+]] = tensor.empty() : tensor<128x1x!tt.ptr> +// CHECK: [[SPLAT_ARG2:%.+]] = linalg.fill ins(%arg2 : !tt.ptr) outs([[EMPTY_PTR128X1]] : tensor<128x1x!tt.ptr>) -> tensor<128x1x!tt.ptr> +// CHECK: [[ADDPTR_ARG2:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG2]], [[MULI_RESULT]] : tensor<128x1x!tt.ptr>, tensor<128x1xi32>) outs([[SPLAT_ARG2]] : tensor<128x1x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[NEW_PTR2:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[NEW_PTR2]] : !tt.ptr +// CHECK: } -> tensor<128x1x!tt.ptr> +// CHECK: [[EMPTY_RANGE2:%.+]] = tensor.empty() : tensor<2xi32> +// CHECK: [[RANGE_2:%.+]] = linalg.generic {{.*}} outs([[EMPTY_RANGE2]] : tensor<2xi32>) { +// CHECK: ^bb0(%out: i32): +// CHECK: [[IDX1:%.+]] = linalg.index 0 : index +// CHECK: [[I32_IDX1:%.+]] = arith.index_cast [[IDX1]] : index to i32 +// CHECK: linalg.yield [[I32_IDX1]] : i32 +// CHECK: } -> tensor<2xi32> +// CHECK: [[EXPANDED_RANGE2:%.+]] = tensor.expand_shape [[RANGE_2]] +// CHECK: [[EMPTY_PTR128X2:%.+]] = tensor.empty() : tensor<128x2x!tt.ptr> +// CHECK: [[BROADCASTED_PTR:%.+]] = linalg.generic {{.*}} ins([[ADDPTR_ARG2]] : tensor<128x1x!tt.ptr>) outs([[EMPTY_PTR128X2]] : tensor<128x2x!tt.ptr>) attrs = {{.*}} { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: linalg.yield [[IN_PTR]] : !tt.ptr +// CHECK: } -> tensor<128x2x!tt.ptr> +// CHECK: [[EMPTY_I32_128X2:%.+]] = tensor.empty() : tensor<128x2xi32> +// CHECK: [[BROADCASTED_I32:%.+]] = linalg.generic {{.*}} ins([[EXPANDED_RANGE2]] : tensor<1x2xi32>) outs([[EMPTY_I32_128X2]] : tensor<128x2xi32>) attrs = {{.*}} { +// CHECK: ^bb0([[IN_I32:%.+]]: i32, [[OUT_I32:%.+]]: i32): +// CHECK: linalg.yield [[IN_I32]] : i32 +// CHECK: } -> tensor<128x2xi32> +// CHECK: [[ADDPTR_FINAL:%.+]] = linalg.generic {{.*}} ins([[BROADCASTED_PTR]], [[BROADCASTED_I32]] : tensor<128x2x!tt.ptr>, tensor<128x2xi32>) outs([[BROADCASTED_PTR]] : tensor<128x2x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[FINAL_PTR:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[FINAL_PTR]] : !tt.ptr +// CHECK: } -> tensor<128x2x!tt.ptr> +// CHECK: tt.store [[ADDPTR_FINAL]], [[INSERTED_SLICE1]] : tensor<128x2x!tt.ptr> diff --git a/test/Conversion/TritonArithToLinalg/reduce_extend_fp32_precision.mlir b/test/Conversion/TritonArithToLinalg/reduce_extend_fp32_precision.mlir index 31c4679f..0f95f19f 100644 --- a/test/Conversion/TritonArithToLinalg/reduce_extend_fp32_precision.mlir +++ b/test/Conversion/TritonArithToLinalg/reduce_extend_fp32_precision.mlir @@ -80,7 +80,7 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @fn1 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<32x!tt.ptr>) -> tensor<32xf16> // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor @@ -120,7 +120,7 @@ module { // // CHECK-LABEL: func.func @fn2 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<32x!tt.ptr>) -> tensor<32xf16> // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor @@ -147,7 +147,7 @@ module { // // CHECK-LABEL: func.func @fn3 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<32x!tt.ptr>) -> tensor<32xbf16> // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor @@ -174,7 +174,7 @@ module { // // CHECK-LABEL: func.func @fn4 // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> +// CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = "tts.load"([[VAR_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<32x!tt.ptr>) -> tensor<32xf32> // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor diff --git a/test/Conversion/TritonArithToLinalg/reducemax_32_256_bf16.mlir b/test/Conversion/TritonArithToLinalg/reducemax_32_256_bf16.mlir index b9a01659..b7aadc15 100644 --- a/test/Conversion/TritonArithToLinalg/reducemax_32_256_bf16.mlir +++ b/test/Conversion/TritonArithToLinalg/reducemax_32_256_bf16.mlir @@ -47,6 +47,7 @@ module { // CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: tensor<256x16x!tt.ptr>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<32xi32> // CHECK-NOT: separator of consecutive DAGs @@ -131,7 +132,6 @@ module { // CHECK: linalg.yield [[VAR_29_6_]] : !tt.ptr // CHECK: } -> tensor<32x256x16x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_25_MEM_:%.+]] = tt.load [[VAR_25_]] : tensor<32x256x16x!tt.ptr> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 // CHECK-DAG: [[VAR_27_:%.+]] = tensor.empty() : tensor<256x16xbf16> // CHECK: [[VAR_28_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_27_]] : tensor<256x16xbf16>) -> tensor<256x16xbf16> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[LOAD_VAR_25_MEM_]] : tensor<32x256x16xbf16>) outs([[VAR_28_]] : tensor<256x16xbf16>) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis0.mlir b/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis0.mlir index 438d757a..ee890109 100644 --- a/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis0.mlir +++ b/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis0.mlir @@ -36,6 +36,7 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<512xi32> // CHECK-NOT: separator of consecutive DAGs @@ -91,7 +92,6 @@ module { // CHECK: linalg.yield [[VAR_21_5_]] : !tt.ptr // CHECK: } -> tensor<256x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_14_MEM_:%.+]] = tt.load [[VAR_14_]] : tensor<512x256x!tt.ptr> -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<256xbf16> // CHECK: [[VAR_20_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_19_]] : tensor<256xbf16>) -> tensor<256xbf16> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[LOAD_VAR_14_MEM_]] : tensor<512x256xbf16>) outs([[VAR_20_]] : tensor<256xbf16>) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis1.mlir b/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis1.mlir index 519da08c..1209460c 100644 --- a/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis1.mlir +++ b/test/Conversion/TritonArithToLinalg/reducesum_512_256_bf16_axis1.mlir @@ -36,6 +36,7 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<512xi32> // CHECK-NOT: separator of consecutive DAGs @@ -94,7 +95,6 @@ module { // CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<256x512xbf16> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_transposed_:%.+]] = linalg.transpose ins([[LOAD_VAR_14_MEM_]] : tensor<512x256xbf16>) outs([[VAR_19_]] : tensor<256x512xbf16>) permutation = [1, 0] -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[VAR_20_:%.+]] = tensor.empty() : tensor<512xbf16> // CHECK: [[VAR_21_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_20_]] : tensor<512xbf16>) -> tensor<512xbf16> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_transposed_]] : tensor<256x512xbf16>) outs([[VAR_21_]] : tensor<512xbf16>) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis0.mlir b/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis0.mlir index 73e54912..6994114a 100644 --- a/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis0.mlir +++ b/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis0.mlir @@ -36,6 +36,7 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<512xi32> // CHECK-NOT: separator of consecutive DAGs @@ -91,7 +92,6 @@ module { // CHECK: linalg.yield [[VAR_21_5_]] : !tt.ptr // CHECK: } -> tensor<256x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_14_MEM_:%.+]] = tt.load [[VAR_14_]] : tensor<512x256x!tt.ptr> -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<256xf32> // CHECK: [[VAR_20_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_19_]] : tensor<256xf32>) -> tensor<256xf32> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[LOAD_VAR_14_MEM_]] : tensor<512x256xf32>) outs([[VAR_20_]] : tensor<256xf32>) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis1.mlir b/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis1.mlir index 5024a2da..a0459ce4 100644 --- a/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis1.mlir +++ b/test/Conversion/TritonArithToLinalg/reducesum_512_256_f32_axis1.mlir @@ -36,6 +36,7 @@ module { // CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<512xi32> // CHECK-NOT: separator of consecutive DAGs @@ -94,7 +95,6 @@ module { // CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<256x512xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_transposed_:%.+]] = linalg.transpose ins([[LOAD_VAR_14_MEM_]] : tensor<512x256xf32>) outs([[VAR_19_]] : tensor<256x512xf32>) permutation = [1, 0] -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_20_:%.+]] = tensor.empty() : tensor<512xf32> // CHECK: [[VAR_21_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_20_]] : tensor<512xf32>) -> tensor<512xf32> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_transposed_]] : tensor<256x512xf32>) outs([[VAR_21_]] : tensor<512xf32>) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/reducesum_middle_dim.mlir b/test/Conversion/TritonArithToLinalg/reducesum_middle_dim.mlir index a140e308..f972c47f 100644 --- a/test/Conversion/TritonArithToLinalg/reducesum_middle_dim.mlir +++ b/test/Conversion/TritonArithToLinalg/reducesum_middle_dim.mlir @@ -47,6 +47,7 @@ module { // CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: tensor<32x16x!tt.ptr>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<32xi32> // CHECK-NOT: separator of consecutive DAGs @@ -131,7 +132,6 @@ module { // CHECK: linalg.yield [[VAR_29_6_]] : !tt.ptr // CHECK: } -> tensor<32x256x16x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_25_MEM_:%.+]] = tt.load [[VAR_25_]] : tensor<32x256x16x!tt.ptr> -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[VAR_27_:%.+]] = tensor.empty() : tensor<32x16xbf16> // CHECK: [[VAR_28_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_27_]] : tensor<32x16xbf16>) -> tensor<32x16xbf16> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[LOAD_VAR_25_MEM_]] : tensor<32x256x16xbf16>) outs([[VAR_28_]] : tensor<32x16xbf16>) dimensions = [1] diff --git a/test/Conversion/TritonArithToLinalg/reducesum_scalar.mlir b/test/Conversion/TritonArithToLinalg/reducesum_scalar.mlir index dd06782a..c7050160 100644 --- a/test/Conversion/TritonArithToLinalg/reducesum_scalar.mlir +++ b/test/Conversion/TritonArithToLinalg/reducesum_scalar.mlir @@ -18,6 +18,7 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<128xi32> // CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<128xi32>) { // CHECK: ^bb0([[out_:%.+]]: i32): @@ -33,7 +34,6 @@ module { // CHECK: linalg.yield [[VAR_8_1_]] : !tt.ptr // CHECK: } -> tensor<128x!tt.ptr> // CHECK-DAG: [[LOAD_VAR_4_MEM_:%.+]] = tt.load [[VAR_4_]] : tensor<128x!tt.ptr> -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_6_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_6_]][] : tensor // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[LOAD_VAR_4_MEM_]] : tensor<128xbf16>) outs([[VAR_inserted_]] : tensor) dimensions = [0] diff --git a/test/Conversion/TritonArithToLinalg/split.mlir b/test/Conversion/TritonArithToLinalg/split.mlir new file mode 100644 index 00000000..de35f543 --- /dev/null +++ b/test/Conversion/TritonArithToLinalg/split.mlir @@ -0,0 +1,66 @@ +// RUN: triton-shared-opt --triton-arith-to-linalg %s | FileCheck %s + +module { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<256x!tt.ptr>, tensor<256xi32> + %3 = tt.load %2 : tensor<256x!tt.ptr> + %4 = tt.reshape %3 {allow_reorder = false} : tensor<256xi32> -> tensor<128x2xi32> + %outLHS, %outRHS = tt.split %4 : tensor<128x2xi32> -> tensor<128xi32> + %5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %5 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %7, %outLHS : tensor<128x!tt.ptr> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %9 = tt.addptr %8, %5 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %9, %outRHS : tensor<128x!tt.ptr> + tt.return + } +} + +// CHECK: func.func @kernel(%arg0: !tt.ptr {{.*}}, %arg1: !tt.ptr {{.*}}, %arg2: !tt.ptr {{.*}}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { +// CHECK: [[EMPTY256:%.+]] = tensor.empty() : tensor<256xi32> +// CHECK: [[RANGE256:%.+]] = linalg.generic {{.*}} outs([[EMPTY256]] : tensor<256xi32>) { +// CHECK: ^bb0(%out: i32): +// CHECK: [[IDX0:%.+]] = linalg.index 0 : index +// CHECK: [[I32_0:%.+]] = arith.index_cast [[IDX0]] : index to i32 +// CHECK: linalg.yield [[I32_0]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: [[EMPTY_PTR256:%.+]] = tensor.empty() : tensor<256x!tt.ptr> +// CHECK: [[SPLAT_ARG0:%.+]] = linalg.fill ins(%arg0 : !tt.ptr) outs([[EMPTY_PTR256]] : tensor<256x!tt.ptr>) -> tensor<256x!tt.ptr> +// CHECK: [[ADDPTR256:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG0]], [[RANGE256]] : tensor<256x!tt.ptr>, tensor<256xi32>) outs([[SPLAT_ARG0]] : tensor<256x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[NEW_PTR:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[NEW_PTR]] : !tt.ptr +// CHECK: } -> tensor<256x!tt.ptr> +// CHECK: [[LOADED256:%.+]] = tt.load [[ADDPTR256]] : tensor<256x!tt.ptr> +// CHECK: [[RESHAPED:%.+]] = tensor.expand_shape [[LOADED256]] +// CHECK: [[SLICE_LHS:%.+]] = tensor.extract_slice [[RESHAPED]]{{\[}}0, 0{{\]}} [128, 1] [1, 1] : tensor<128x2xi32> to tensor<128xi32> +// CHECK: [[SLICE_RHS:%.+]] = tensor.extract_slice [[RESHAPED]]{{\[}}0, 1{{\]}} [128, 1] [1, 1] : tensor<128x2xi32> to tensor<128xi32> +// CHECK: [[EMPTY128:%.+]] = tensor.empty() : tensor<128xi32> +// CHECK: [[RANGE128:%.+]] = linalg.generic {{.*}} outs([[EMPTY128]] : tensor<128xi32>) { +// CHECK: ^bb0(%out: i32): +// CHECK: [[IDX1:%.+]] = linalg.index 0 : index +// CHECK: [[I32_1:%.+]] = arith.index_cast [[IDX1]] : index to i32 +// CHECK: linalg.yield [[I32_1]] : i32 +// CHECK: } -> tensor<128xi32> +// CHECK: [[EMPTY_PTR128_1:%.+]] = tensor.empty() : tensor<128x!tt.ptr> +// CHECK: [[SPLAT_ARG1:%.+]] = linalg.fill ins(%arg1 : !tt.ptr) outs([[EMPTY_PTR128_1]] : tensor<128x!tt.ptr>) -> tensor<128x!tt.ptr> +// CHECK: [[ADDPTR128_1:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG1]], [[RANGE128]] : tensor<128x!tt.ptr>, tensor<128xi32>) outs([[SPLAT_ARG1]] : tensor<128x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[NEW_PTR1:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[NEW_PTR1]] : !tt.ptr +// CHECK: } -> tensor<128x!tt.ptr> +// CHECK: tt.store [[ADDPTR128_1]], [[SLICE_LHS]] : tensor<128x!tt.ptr> +// CHECK: [[EMPTY_PTR128_2:%.+]] = tensor.empty() : tensor<128x!tt.ptr> +// CHECK: [[SPLAT_ARG2:%.+]] = linalg.fill ins(%arg2 : !tt.ptr) outs([[EMPTY_PTR128_2]] : tensor<128x!tt.ptr>) -> tensor<128x!tt.ptr> +// CHECK: [[ADDPTR128_2:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG2]], [[RANGE128]] : tensor<128x!tt.ptr>, tensor<128xi32>) outs([[SPLAT_ARG2]] : tensor<128x!tt.ptr>) { +// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr): +// CHECK: [[NEW_PTR2:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr, i32 +// CHECK: linalg.yield [[NEW_PTR2]] : !tt.ptr +// CHECK: } -> tensor<128x!tt.ptr> +// CHECK: tt.store [[ADDPTR128_2]], [[SLICE_RHS]] : tensor<128x!tt.ptr> +// CHECK: return +// CHECK: } +// CHECK: } \ No newline at end of file diff --git a/test/Conversion/TritonArithToLinalg/triton_assert.mlir b/test/Conversion/TritonArithToLinalg/triton_assert.mlir index cd6d6306..66929b60 100644 --- a/test/Conversion/TritonArithToLinalg/triton_assert.mlir +++ b/test/Conversion/TritonArithToLinalg/triton_assert.mlir @@ -11,8 +11,6 @@ tt.func public @assert_lol(%arg0: i32) { // CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { // CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<1xi1> -// CHECK: [[VAR_2_:%.+]] = linalg.fill ins([[VAR_0_]] : i1) outs([[VAR_1_]] : tensor<1xi1>) -> tensor<1xi1> // CHECK: cf.assert [[VAR_0_]], ".py:0: Assertion `lol` failed" // CHECK: return // CHECK: }