Implementing split, join, and cat (#165)

Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>
parsifal-47 and parsifal-47 authored Dec 5, 2024
1 parent 5b17b80 commit d5b7bee
Showing 25 changed files with 325 additions and 59 deletions.
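This commit also removes test_cat, test_split, and test_join from the unsupported-test list in python/examples/conftest.py. For context, a minimal Triton kernel exercising these ops might look like the sketch below (illustrative only; the kernel, its names, and its shapes are not part of this commit):

import triton
import triton.language as tl

@triton.jit
def cat_split_join_kernel(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    # tl.join stacks the two blocks along a new trailing dimension of size 2.
    z = tl.join(x, y)
    # tl.split undoes that, yielding two [BLOCK]-shaped tensors again.
    a, b = tl.split(z)
    # tl.cat concatenates along the leading dimension; with can_reorder=True
    # the element order of the result is not guaranteed.
    c = tl.cat(a, b, can_reorder=True)
    tl.store(out_ptr + tl.arange(0, 2 * BLOCK), c)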
@@ -993,6 +993,97 @@ struct PreciseDivConverter : public OpConversionPattern<triton::PreciseDivFOp> {
}
};

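// Converts tt.cat into a tensor.concat along the leading (0th) dimension.
// The resulting tensor.concat is decomposed later; see the
// TritonArithToLinalgPass change below.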
struct CatConverter : public OpConversionPattern<triton::CatOp> {
  using OpConversionPattern<triton::CatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto replacement = rewriter.create<tensor::ConcatOp>(
        op.getLoc(), 0 /* concat dimension */, adaptor.getOperands());

    rewriter.replaceOp(op, replacement);

    return success();
  }
};

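// Converts tt.split by extracting two rank-reduced slices (offsets 0 and 1,
// size 1) along the innermost dimension of the input, which tt.split
// requires to have size 2.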
struct SplitConverter : public OpConversionPattern<triton::SplitOp> {
  using OpConversionPattern<triton::SplitOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getOperand();
    auto inputType = cast<RankedTensorType>(input.getType());

    Type resultType = op.getResults().front().getType();
    auto resultTensor = cast<RankedTensorType>(resultType);
    auto shape = inputType.getShape();

    SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes = llvm::to_vector(
        llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
          return rewriter.getIndexAttr(dim);
        }));

    SmallVector<Value> results;

    for (int i = 0; i < 2; ++i) {
      offsets.pop_back();
      sizes.pop_back();

      offsets.push_back(rewriter.getIndexAttr(i));
      sizes.push_back(rewriter.getIndexAttr(1));
      Value slice = rewriter.create<tensor::ExtractSliceOp>(
          loc, resultTensor, input, offsets, sizes, strides);
      results.push_back(slice);
    }

    rewriter.replaceOp(op, results);
    return success();
  }
};

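// Converts tt.join by allocating an empty tensor whose trailing dimension has
// size 2 and inserting each operand into it with tensor.insert_slice.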
struct JoinConverter : public OpConversionPattern<triton::JoinOp> {
  using OpConversionPattern<triton::JoinOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange inputs = op.getOperands();

    auto resultType = cast<RankedTensorType>(op.getResult().getType());

    auto loc = op.getLoc();
    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType());

    auto shape = resultType.getShape();

    SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes = llvm::to_vector(
        llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
          return rewriter.getIndexAttr(dim);
        }));

    for (int i = 0; i < 2; ++i) {
      offsets.pop_back();
      sizes.pop_back();

      offsets.push_back(rewriter.getIndexAttr(i));
      sizes.push_back(rewriter.getIndexAttr(1));
      result = rewriter.create<tensor::InsertSliceOp>(
          loc, inputs[i], result, offsets, sizes, strides);
    }

    rewriter.replaceOp(op, result);

    return success();
  }
};

struct MulHiUIOpConverter : public OpConversionPattern<triton::MulhiUIOp> {
using OpConversionPattern<triton::MulhiUIOp>::OpConversionPattern;

@@ -25,6 +25,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Types.h"
@@ -224,7 +225,6 @@ class StructuredToMemrefPass

LogicalResult convertArgsToMemrefType() {
auto moduleOp = getOperation();

RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
TritonFunctionSignatureConverter typeConverter;
3 changes: 3 additions & 0 deletions lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
@@ -63,6 +63,9 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
patterns.add<MulHiUIOpConverter>(patterns.getContext());
patterns.add<PreciseSqrtConverter>(patterns.getContext());
patterns.add<PreciseDivConverter>(patterns.getContext());
patterns.add<CatConverter>(patterns.getContext());
patterns.add<SplitConverter>(patterns.getContext());
patterns.add<JoinConverter>(patterns.getContext());
patterns.add<FpToFpConverter>(patterns.getContext());
patterns.add<ClampConverter>(patterns.getContext());
patterns.add<MatmulConverter>(patterns.getContext());
18 changes: 18 additions & 0 deletions lib/Conversion/TritonArithToLinalg/TritonArithToLinalgPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -73,6 +74,19 @@ class TritonArithToLinalgPass
}
}

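  // Rewrites the tensor.concat ops produced by CatConverter using the
  // upstream decomposition patterns (into tensor.insert_slice on a
  // tensor.empty), presumably because the rest of the pipeline does not
  // handle tensor.concat directly.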
  LogicalResult applyTensorConcatDecomposition() {
    auto moduleOp = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    tensor::populateDecomposeTensorConcatPatterns(patterns);

    if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
      return failure();
    }
    return success();
  }

public:
void getDependentDialects(DialectRegistry &registry) const override {
registry
@@ -161,6 +175,10 @@ class TritonArithToLinalgPass
signalPassFailure();
}

if (failed(applyTensorConcatDecomposition())) {
signalPassFailure();
}

// Convert tt.func and tt.return into func's counterparts
if (ttToFuncFunc) {
moduleOp.walk([&](triton::FuncOp func) {
3 changes: 3 additions & 0 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -55,6 +55,9 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
patterns.add<MulHiUIOpConverter>(patterns.getContext());
patterns.add<PreciseSqrtConverter>(patterns.getContext());
patterns.add<PreciseDivConverter>(patterns.getContext());
patterns.add<CatConverter>(patterns.getContext());
patterns.add<SplitConverter>(patterns.getContext());
patterns.add<JoinConverter>(patterns.getContext());
patterns.add<FpToFpConverter>(patterns.getContext());
patterns.add<ClampConverter>(patterns.getContext());
patterns.add<AssertConverter>(patterns.getContext());
3 changes: 0 additions & 3 deletions python/examples/conftest.py
@@ -20,7 +20,6 @@ def device(request):

tests_not_supported = {
"test_bin_op",
"test_split",
"test_split_to_scalar",
"test_interleave_scalars",
"test_pointer_arguments",
@@ -34,7 +33,6 @@ def device(request):
"test_ptx_cast",
"test_compare_op",
"test_maxnreg",
"test_join",
"test_join_scalars",
"test_join_with_mma",
"test_interleave",
@@ -51,7 +49,6 @@ def device(request):
"test_atomic_cas",
"test_tensor_atomic_cas",
"test_cast",
"test_cat",
"test_store_constant",
"test_reduce",
"test_reduce1d",
6 changes: 3 additions & 3 deletions test/Conversion/TritonArithToLinalg/block_ptr_advance.mlir
@@ -36,12 +36,12 @@ module {
// CHECK-LABEL: func.func @matmul_kernel_with_block_pointers_01234567891011
// CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32, [[PARAM_15_:%.+]]: i32, [[PARAM_16_:%.+]]: i32, [[PARAM_17_:%.+]]: i32, [[PARAM_18_:%.+]]: i32, [[PARAM_19_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16
// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x64xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<128x64xbf16>) -> tensor<128x64xbf16>
// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : i32
// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x64xbf16>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<128x64xbf16>) -> tensor<128x64xbf16>
// CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_3_]] : i32 to i64
// CHECK-DAG: [[VAR_3_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64
// CHECK-DAG: [[VAR_4_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
3 changes: 1 addition & 2 deletions test/Conversion/TritonArithToLinalg/convert_addi_reduce.mlir
@@ -19,9 +19,8 @@ module {
// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4096xi32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<4096xi32>) -> tensor<4096xi32>
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.alloc_tensor() : tensor<i32>
// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_1_]] into [[VAR_2_]][] : tensor<i32>
// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_2_]][] : tensor<i32>
// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_1_]] : tensor<4096xi32>) outs([[VAR_inserted_]] : tensor<i32>) dimensions = [0]
// CHECK: ([[in_:%.+]]: i32, [[in_]]it: i32) {
// CHECK: [[VAR_3_:%.+]] = arith.addi [[in_]], [[in_]]it : i32
26 changes: 12 additions & 14 deletions test/Conversion/TritonArithToLinalg/convert_argmin_argmax.mlir
@@ -60,6 +60,8 @@ module {
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @argmax_012
// CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<i32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32
// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32>
// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) {
@@ -71,20 +73,18 @@ module {
// CHECK: [[VAR_3_:%.+]] = tensor.empty() : tensor<4096xi32>
// CHECK: [[VAR_4_:%.+]] = linalg.fill ins([[VAR_0_]] : i32) outs([[VAR_3_]] : tensor<4096xi32>) -> tensor<4096xi32>
// CHECK: [[VAR_5_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xi32>, tensor<4096xi32>) outs([[VAR_4_]] : tensor<4096xi32>) {
// CHECK: ^bb0([[in_:%.+]]: i32, [[in_]]_1: i32, [[out_]]: i32):
// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_]]_1 : i32
// CHECK: ^bb0([[in_:%.+]]: i32, [[in_1:%.+]]: i32, [[out_]]: i32):
// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_1]] : i32
// CHECK: linalg.yield [[VAR_15_1_]] : i32
// CHECK: } -> tensor<4096xi32>
// CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096x!tt.ptr<f32>>
// CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[PARAM_0_]] : !tt.ptr<f32>) outs([[VAR_6_]] : tensor<4096x!tt.ptr<f32>>) -> tensor<4096x!tt.ptr<f32>>
// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<4096x!tt.ptr<f32>>, tensor<4096xi32>) outs([[VAR_7_]] : tensor<4096x!tt.ptr<f32>>) {
// CHECK: ^bb0([[in_]]: !tt.ptr<f32>, [[in_]]_1: i32, [[out_]]: !tt.ptr<f32>):
// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_]]_1 : !tt.ptr<f32>, i32
// CHECK: ^bb0([[in_]]: !tt.ptr<f32>, [[in_1:%.+]]: i32, [[out_]]: !tt.ptr<f32>):
// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr<f32>, i32
// CHECK: linalg.yield [[VAR_15_2_]] : !tt.ptr<f32>
// CHECK: } -> tensor<4096x!tt.ptr<f32>>
// CHECK-DAG: [[LOAD_VAR_8_MEM_:%.+]] = tt.load [[VAR_8_]] : tensor<4096x!tt.ptr<f32>>
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[VAR_10_:%.+]] = tensor.empty() : tensor<f32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_11_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_10_]] : tensor<f32>) -> tensor<f32>
@@ -102,7 +102,6 @@ module {
// CHECK-DAG: [[VAR_21_:%.+]] = arith.select [[VAR_19_]], [[in_1_]], [[init_2_]] : i32
// CHECK: linalg.yield [[VAR_20_]], [[VAR_21_]] : f32, i32
// CHECK: }
// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#0[] : tensor<f32>
// CHECK-DAG: [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor<i32>
// CHECK-DAG: [[VAR_14_:%.+]] = tt.addptr [[PARAM_1_]], [[PARAM_6_]] : !tt.ptr<i32>, i32
// CHECK: tt.store [[VAR_14_]], [[VAR_extracted_0_]] : !tt.ptr<i32>
@@ -112,6 +111,8 @@ module {
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @argmin_012
// CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<i32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32
// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32>
// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) {
@@ -123,20 +124,18 @@ module {
// CHECK: [[VAR_3_:%.+]] = tensor.empty() : tensor<4096xi32>
// CHECK: [[VAR_4_:%.+]] = linalg.fill ins([[VAR_0_]] : i32) outs([[VAR_3_]] : tensor<4096xi32>) -> tensor<4096xi32>
// CHECK: [[VAR_5_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xi32>, tensor<4096xi32>) outs([[VAR_4_]] : tensor<4096xi32>) {
// CHECK: ^bb0([[in_:%.+]]: i32, [[in_]]_1: i32, [[out_]]: i32):
// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_]]_1 : i32
// CHECK: ^bb0([[in_:%.+]]: i32, [[in_1:%.+]]: i32, [[out_]]: i32):
// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[in_]], [[in_1]] : i32
// CHECK: linalg.yield [[VAR_15_1_]] : i32
// CHECK: } -> tensor<4096xi32>
// CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096x!tt.ptr<f32>>
// CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[PARAM_0_]] : !tt.ptr<f32>) outs([[VAR_6_]] : tensor<4096x!tt.ptr<f32>>) -> tensor<4096x!tt.ptr<f32>>
// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<4096x!tt.ptr<f32>>, tensor<4096xi32>) outs([[VAR_7_]] : tensor<4096x!tt.ptr<f32>>) {
// CHECK: ^bb0([[in_]]: !tt.ptr<f32>, [[in_]]_1: i32, [[out_]]: !tt.ptr<f32>):
// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_]]_1 : !tt.ptr<f32>, i32
// CHECK: ^bb0([[in_]]: !tt.ptr<f32>, [[in_1:%.+]]: i32, [[out_]]: !tt.ptr<f32>):
// CHECK: [[VAR_15_2_:%.+]] = tt.addptr [[in_]], [[in_1]] : !tt.ptr<f32>, i32
// CHECK: linalg.yield [[VAR_15_2_]] : !tt.ptr<f32>
// CHECK: } -> tensor<4096x!tt.ptr<f32>>
// CHECK-DAG: [[LOAD_VAR_8_MEM_:%.+]] = tt.load [[VAR_8_]] : tensor<4096x!tt.ptr<f32>>
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[VAR_10_:%.+]] = tensor.empty() : tensor<f32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_11_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_10_]] : tensor<f32>) -> tensor<f32>
@@ -154,7 +153,6 @@ module {
// CHECK-DAG: [[VAR_21_:%.+]] = arith.select [[VAR_19_]], [[in_1_]], [[init_2_]] : i32
// CHECK: linalg.yield [[VAR_20_]], [[VAR_21_]] : f32, i32
// CHECK: }
// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#0[] : tensor<f32>
// CHECK-DAG: [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor<i32>
// CHECK-DAG: [[VAR_14_:%.+]] = tt.addptr [[PARAM_1_]], [[PARAM_6_]] : !tt.ptr<i32>, i32
// CHECK: tt.store [[VAR_14_]], [[VAR_extracted_0_]] : !tt.ptr<i32>
@@ -113,6 +113,8 @@ module {
// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)>
// CHECK-LABEL: func.func @test_argmax
// CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32>
// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) {
// CHECK: ^bb0([[out_:%.+]]: i32):
@@ -164,9 +166,7 @@ module {
// CHECK: ^bb0([[in_]]: i32, [[out_]]: i32):
// CHECK: linalg.yield [[in_]] : i32
// CHECK: } -> tensor<4x4xi32>
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32>
// CHECK: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_20_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_19_]] : tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: [[VAR_21_:%.+]] = tensor.empty() : tensor<4xi32>
@@ -205,6 +205,8 @@ module {
// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0, d1) -> (0, d1)>
// CHECK-LABEL: func.func @test_argmin
// CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32>
// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) {
// CHECK: ^bb0([[out_]]: i32):
@@ -256,9 +258,7 @@ module {
// CHECK: ^bb0([[in_]]: i32, [[out_]]: i32):
// CHECK: linalg.yield [[in_]] : i32
// CHECK: } -> tensor<4x4xi32>
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32>
// CHECK: [[VAR_19_:%.+]] = tensor.empty() : tensor<4xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_20_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_19_]] : tensor<4xf32>) -> tensor<4xf32>
// CHECK-DAG: [[VAR_21_:%.+]] = tensor.empty() : tensor<4xi32>