Addressing Code Review Feedback
parsifal-47 committed Nov 28, 2024
1 parent 289274e commit c9cf66b
Showing 5 changed files with 200 additions and 128 deletions.
@@ -150,33 +150,6 @@ static std::optional<unsigned> getBitWidth(Type a) {
return std::nullopt;
}

static void processTwoStridesLastDim(RankedTensorType type, ConversionPatternRewriter &rewriter,
                                     std::function<void(SmallVector<OpFoldResult> &/* offsets */,
                                                        SmallVector<OpFoldResult> &/* sizes */,
                                                        SmallVector<OpFoldResult> &/* strides */,
                                                        int /* index */)> op) {
  int64_t rank = type.getRank();
  auto shape = type.getShape();

  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<int64_t> stridesInt(rank, 1);

  for (size_t j = 0; j < shape.size(); ++j) {
    offsets.push_back(rewriter.getIndexAttr(0));
    sizes.push_back(rewriter.getIndexAttr(shape[j]));
    strides.push_back(rewriter.getIndexAttr(stridesInt[j]));
  }

  for (int i = 0; i < 2; ++i) {
    offsets.pop_back();
    sizes.pop_back();

    offsets.push_back(rewriter.getIndexAttr(i));
    sizes.push_back(rewriter.getIndexAttr(1));
    op(offsets, sizes, strides, i);
  }
}

//===----------------------------------------------------------------------===//
// Op Lowering Patterns
//===----------------------------------------------------------------------===//
@@ -1003,16 +976,27 @@ struct SplitConverter : public OpConversionPattern<triton::SplitOp> {

    Type resultType = op.getResults().front().getType();
    auto resultTensor = cast<RankedTensorType>(resultType);

    SmallVector<Value, 2> results;
    processTwoStridesLastDim(inputType, rewriter,
        [&rewriter, &results, &loc, &resultTensor, &input](SmallVector<OpFoldResult> &offsets,
                                                           SmallVector<OpFoldResult> &sizes,
                                                           SmallVector<OpFoldResult> &strides, int index) {
          Value slice = rewriter.create<tensor::ExtractSliceOp>(
              loc, resultTensor, input, offsets, sizes, strides);
          results.push_back(slice);
        });
    auto shape = inputType.getShape();

    SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes =
        llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
          return rewriter.getIndexAttr(dim);
        }));

    SmallVector<Value> results;

    for (int i = 0; i < 2; ++i) {
      offsets.pop_back();
      sizes.pop_back();

      offsets.push_back(rewriter.getIndexAttr(i));
      sizes.push_back(rewriter.getIndexAttr(1));
      Value slice = rewriter.create<tensor::ExtractSliceOp>(
          loc, resultTensor, input, offsets, sizes, strides);
      results.push_back(slice);
    }

    rewriter.replaceOp(op, results);
    return success();
@@ -1032,12 +1016,23 @@ struct JoinConverter : public OpConversionPattern<triton::JoinOp> {
    auto loc = op.getLoc();
    Value result = rewriter.create<tensor::EmptyOp>(loc, resultType.getShape(), resultType.getElementType());

    processTwoStridesLastDim(resultType, rewriter,
        [&rewriter, &result, &loc, &inputs](SmallVector<OpFoldResult> &offsets,
                                            SmallVector<OpFoldResult> &sizes,
                                            SmallVector<OpFoldResult> &strides, int index) {
          result = rewriter.create<tensor::InsertSliceOp>(loc, inputs[index], result, offsets, sizes, strides);
        });
    auto shape = resultType.getShape();

    SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes =
        llvm::to_vector(llvm::map_range(shape, [&](int64_t dim) -> OpFoldResult {
          return rewriter.getIndexAttr(dim);
        }));

    for (int i = 0; i < 2; ++i) {
      offsets.pop_back();
      sizes.pop_back();

      offsets.push_back(rewriter.getIndexAttr(i));
      sizes.push_back(rewriter.getIndexAttr(1));
      result = rewriter.create<tensor::InsertSliceOp>(loc, inputs[i], result, offsets, sizes, strides);
    }

    rewriter.replaceOp(op, result);

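As a minimal sketch of what this slice-based lowering produces for a tensor<128x2xi32> value (illustrative only; SSA names such as %src, %a, %b, and %empty are placeholders, and the exact output is pinned down by the new FileCheck tests below):

// tt.split: carve the size-2 last dimension into two rank-reduced slices.
%lhs = tensor.extract_slice %src[0, 0] [128, 1] [1, 1] : tensor<128x2xi32> to tensor<128xi32>
%rhs = tensor.extract_slice %src[0, 1] [128, 1] [1, 1] : tensor<128x2xi32> to tensor<128xi32>

// tt.join: insert both inputs into an empty 128x2 tensor at last-dim offsets 0 and 1.
%empty = tensor.empty() : tensor<128x2xi32>
%j0 = tensor.insert_slice %a into %empty[0, 0] [128, 1] [1, 1] : tensor<128xi32> into tensor<128x2xi32>
%j1 = tensor.insert_slice %b into %j0[0, 1] [128, 1] [1, 1] : tensor<128xi32> into tensor<128x2xi32>

In both converters the loop only rewrites the trailing offset and size; all leading dimensions keep offset 0, full size, and stride 1.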
95 changes: 95 additions & 0 deletions test/Conversion/TritonArithToLinalg/join.mlir
@@ -0,0 +1,95 @@
// RUN: triton-shared-opt --triton-arith-to-linalg %s | FileCheck %s

module {
  tt.func public @kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
                         %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<2> : tensor<128x1xi32>
    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
    %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
    %3 = tt.load %2 : tensor<128x!tt.ptr<i32>>
    %4 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
    %5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
    %6 = tt.load %5 : tensor<128x!tt.ptr<i32>>
    %7 = tt.join %3, %6 : tensor<128xi32> -> tensor<128x2xi32>
    %8 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
    %9 = arith.muli %8, %cst : tensor<128x1xi32>
    %10 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<128x1x!tt.ptr<i32>>
    %11 = tt.addptr %10, %9 : tensor<128x1x!tt.ptr<i32>>, tensor<128x1xi32>
    %12 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
    %13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32>
    %14 = tt.broadcast %11 : tensor<128x1x!tt.ptr<i32>> -> tensor<128x2x!tt.ptr<i32>>
    %15 = tt.broadcast %13 : tensor<1x2xi32> -> tensor<128x2xi32>
    %16 = tt.addptr %14, %15 : tensor<128x2x!tt.ptr<i32>>, tensor<128x2xi32>
    tt.store %16, %7 : tensor<128x2x!tt.ptr<i32>>
    tt.return
  }
}

// CHECK: func.func @kernel(%arg0: !tt.ptr<i32> {{.*}}, %arg1: !tt.ptr<i32> {{.*}}, %arg2: !tt.ptr<i32> {{.*}}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
// CHECK: [[C2_I32:%.+]] = arith.constant 2 : i32
// CHECK: [[EMPTY_128X1_I32:%.+]] = tensor.empty() : tensor<128x1xi32>
// CHECK: [[FILLED_C2:%.+]] = linalg.fill ins([[C2_I32]] : i32) outs([[EMPTY_128X1_I32]] : tensor<128x1xi32>) -> tensor<128x1xi32>
// CHECK: [[EMPTY_128_I32:%.+]] = tensor.empty() : tensor<128xi32>
// CHECK: [[RANGE_128:%.+]] = linalg.generic {{.*}} outs([[EMPTY_128_I32]] : tensor<128xi32>) {
// CHECK: ^bb0(%out: i32):
// CHECK: [[IDX0:%.+]] = linalg.index 0 : index
// CHECK: [[I32_IDX0:%.+]] = arith.index_cast [[IDX0]] : index to i32
// CHECK: linalg.yield [[I32_IDX0]] : i32
// CHECK: } -> tensor<128xi32>
// CHECK: [[EMPTY_PTR128:%.+]] = tensor.empty() : tensor<128x!tt.ptr<i32>>
// CHECK: [[SPLAT_ARG0:%.+]] = linalg.fill ins(%arg0 : !tt.ptr<i32>) outs([[EMPTY_PTR128]] : tensor<128x!tt.ptr<i32>>) -> tensor<128x!tt.ptr<i32>>
// CHECK: [[ADDPTR_ARG0:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG0]], [[RANGE_128]] : tensor<128x!tt.ptr<i32>>, tensor<128xi32>) outs([[SPLAT_ARG0]] : tensor<128x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[NEW_PTR0:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[NEW_PTR0]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x!tt.ptr<i32>>
// CHECK: [[LOADED_ARG0:%.+]] = tt.load [[ADDPTR_ARG0]] : tensor<128x!tt.ptr<i32>>
// CHECK: [[EMPTY_PTR128_1:%.+]] = tensor.empty() : tensor<128x!tt.ptr<i32>>
// CHECK: [[SPLAT_ARG1:%.+]] = linalg.fill ins(%arg1 : !tt.ptr<i32>) outs([[EMPTY_PTR128_1]] : tensor<128x!tt.ptr<i32>>) -> tensor<128x!tt.ptr<i32>>
// CHECK: [[ADDPTR_ARG1:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG1]], [[RANGE_128]] : tensor<128x!tt.ptr<i32>>, tensor<128xi32>) outs([[SPLAT_ARG1]] : tensor<128x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[NEW_PTR1:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[NEW_PTR1]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x!tt.ptr<i32>>
// CHECK: [[LOADED_ARG1:%.+]] = tt.load [[ADDPTR_ARG1]] : tensor<128x!tt.ptr<i32>>
// CHECK: [[EMPTY_JOIN:%.+]] = tensor.empty() : tensor<128x2xi32>
// CHECK: [[INSERTED_SLICE0:%.+]] = tensor.insert_slice [[LOADED_ARG0]] into [[EMPTY_JOIN]]{{\[}}0, 0{{\]}} [128, 1] [1, 1] : tensor<128xi32> into tensor<128x2xi32>
// CHECK: [[INSERTED_SLICE1:%.+]] = tensor.insert_slice [[LOADED_ARG1]] into [[INSERTED_SLICE0]]{{\[}}0, 1{{\]}} [128, 1] [1, 1] : tensor<128xi32> into tensor<128x2xi32>
// CHECK: [[EXPANDED_RANGE:%.+]] = tensor.expand_shape
// CHECK: [[MULI_RESULT:%.+]] = linalg.generic {{.*}} ins([[EXPANDED_RANGE]], [[FILLED_C2]] : tensor<128x1xi32>, tensor<128x1xi32>) outs([[EXPANDED_RANGE]] : tensor<128x1xi32>) {
// CHECK: ^bb0([[IN_I32_0:%.+]]: i32, [[IN_I32_1:%.+]]: i32, %out: i32):
// CHECK: [[MUL_RESULT:%.+]] = arith.muli [[IN_I32_0]], [[IN_I32_1]] : i32
// CHECK: linalg.yield [[MUL_RESULT]] : i32
// CHECK: } -> tensor<128x1xi32>
// CHECK: [[EMPTY_PTR128X1:%.+]] = tensor.empty() : tensor<128x1x!tt.ptr<i32>>
// CHECK: [[SPLAT_ARG2:%.+]] = linalg.fill ins(%arg2 : !tt.ptr<i32>) outs([[EMPTY_PTR128X1]] : tensor<128x1x!tt.ptr<i32>>) -> tensor<128x1x!tt.ptr<i32>>
// CHECK: [[ADDPTR_ARG2:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG2]], [[MULI_RESULT]] : tensor<128x1x!tt.ptr<i32>>, tensor<128x1xi32>) outs([[SPLAT_ARG2]] : tensor<128x1x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[NEW_PTR2:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[NEW_PTR2]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x1x!tt.ptr<i32>>
// CHECK: [[EMPTY_RANGE2:%.+]] = tensor.empty() : tensor<2xi32>
// CHECK: [[RANGE_2:%.+]] = linalg.generic {{.*}} outs([[EMPTY_RANGE2]] : tensor<2xi32>) {
// CHECK: ^bb0(%out: i32):
// CHECK: [[IDX1:%.+]] = linalg.index 0 : index
// CHECK: [[I32_IDX1:%.+]] = arith.index_cast [[IDX1]] : index to i32
// CHECK: linalg.yield [[I32_IDX1]] : i32
// CHECK: } -> tensor<2xi32>
// CHECK: [[EXPANDED_RANGE2:%.+]] = tensor.expand_shape [[RANGE_2]]
// CHECK: [[EMPTY_PTR128X2:%.+]] = tensor.empty() : tensor<128x2x!tt.ptr<i32>>
// CHECK: [[BROADCASTED_PTR:%.+]] = linalg.generic {{.*}} ins([[ADDPTR_ARG2]] : tensor<128x1x!tt.ptr<i32>>) outs([[EMPTY_PTR128X2]] : tensor<128x2x!tt.ptr<i32>>) attrs = {{.*}} {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: linalg.yield [[IN_PTR]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x2x!tt.ptr<i32>>
// CHECK: [[EMPTY_I32_128X2:%.+]] = tensor.empty() : tensor<128x2xi32>
// CHECK: [[BROADCASTED_I32:%.+]] = linalg.generic {{.*}} ins([[EXPANDED_RANGE2]] : tensor<1x2xi32>) outs([[EMPTY_I32_128X2]] : tensor<128x2xi32>) attrs = {{.*}} {
// CHECK: ^bb0([[IN_I32:%.+]]: i32, [[OUT_I32:%.+]]: i32):
// CHECK: linalg.yield [[IN_I32]] : i32
// CHECK: } -> tensor<128x2xi32>
// CHECK: [[ADDPTR_FINAL:%.+]] = linalg.generic {{.*}} ins([[BROADCASTED_PTR]], [[BROADCASTED_I32]] : tensor<128x2x!tt.ptr<i32>>, tensor<128x2xi32>) outs([[BROADCASTED_PTR]] : tensor<128x2x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[FINAL_PTR:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[FINAL_PTR]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x2x!tt.ptr<i32>>
// CHECK: tt.store [[ADDPTR_FINAL]], [[INSERTED_SLICE1]] : tensor<128x2x!tt.ptr<i32>>
67 changes: 67 additions & 0 deletions test/Conversion/TritonArithToLinalg/split.mlir
@@ -0,0 +1,67 @@
// RUN: triton-shared-opt --triton-arith-to-linalg %s | FileCheck %s

module {
  tt.func public @kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>>
    %2 = tt.addptr %1, %0 : tensor<256x!tt.ptr<i32>>, tensor<256xi32>
    %3 = tt.load %2 : tensor<256x!tt.ptr<i32>>
    %4 = tt.reshape %3 {allow_reorder = false} : tensor<256xi32> -> tensor<128x2xi32>
    %outLHS, %outRHS = tt.split %4 : tensor<128x2xi32> -> tensor<128xi32>
    %5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %6 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
    %7 = tt.addptr %6, %5 : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
    tt.store %7, %outLHS : tensor<128x!tt.ptr<i32>>
    %8 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
    %9 = tt.addptr %8, %5 : tensor<128x!tt.ptr<i32>>, tensor<128xi32>
    tt.store %9, %outRHS : tensor<128x!tt.ptr<i32>>
    tt.return
  }
}

// CHECK: func.func @kernel(%arg0: !tt.ptr<i32> {{.*}}, %arg1: !tt.ptr<i32> {{.*}}, %arg2: !tt.ptr<i32> {{.*}}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
// CHECK: [[CST:%.+]] = arith.constant dense<[128, 2]> : tensor<2xi64>
// CHECK: [[EMPTY256:%.+]] = tensor.empty() : tensor<256xi32>
// CHECK: [[RANGE256:%.+]] = linalg.generic {{.*}} outs([[EMPTY256]] : tensor<256xi32>) {
// CHECK: ^bb0(%out: i32):
// CHECK: [[IDX0:%.+]] = linalg.index 0 : index
// CHECK: [[I32_0:%.+]] = arith.index_cast [[IDX0]] : index to i32
// CHECK: linalg.yield [[I32_0]] : i32
// CHECK: } -> tensor<256xi32>
// CHECK: [[EMPTY_PTR256:%.+]] = tensor.empty() : tensor<256x!tt.ptr<i32>>
// CHECK: [[SPLAT_ARG0:%.+]] = linalg.fill ins(%arg0 : !tt.ptr<i32>) outs([[EMPTY_PTR256]] : tensor<256x!tt.ptr<i32>>) -> tensor<256x!tt.ptr<i32>>
// CHECK: [[ADDPTR256:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG0]], [[RANGE256]] : tensor<256x!tt.ptr<i32>>, tensor<256xi32>) outs([[SPLAT_ARG0]] : tensor<256x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[NEW_PTR:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[NEW_PTR]] : !tt.ptr<i32>
// CHECK: } -> tensor<256x!tt.ptr<i32>>
// CHECK: [[LOADED256:%.+]] = tt.load [[ADDPTR256]] : tensor<256x!tt.ptr<i32>>
// CHECK: [[RESHAPED:%.+]] = tensor.reshape [[LOADED256]]([[CST]]) : (tensor<256xi32>, tensor<2xi64>) -> tensor<128x2xi32>
// CHECK: [[SLICE_LHS:%.+]] = tensor.extract_slice [[RESHAPED]]{{\[}}0, 0{{\]}} [128, 1] [1, 1] : tensor<128x2xi32> to tensor<128xi32>
// CHECK: [[SLICE_RHS:%.+]] = tensor.extract_slice [[RESHAPED]]{{\[}}0, 1{{\]}} [128, 1] [1, 1] : tensor<128x2xi32> to tensor<128xi32>
// CHECK: [[EMPTY128:%.+]] = tensor.empty() : tensor<128xi32>
// CHECK: [[RANGE128:%.+]] = linalg.generic {{.*}} outs([[EMPTY128]] : tensor<128xi32>) {
// CHECK: ^bb0(%out: i32):
// CHECK: [[IDX1:%.+]] = linalg.index 0 : index
// CHECK: [[I32_1:%.+]] = arith.index_cast [[IDX1]] : index to i32
// CHECK: linalg.yield [[I32_1]] : i32
// CHECK: } -> tensor<128xi32>
// CHECK: [[EMPTY_PTR128_1:%.+]] = tensor.empty() : tensor<128x!tt.ptr<i32>>
// CHECK: [[SPLAT_ARG1:%.+]] = linalg.fill ins(%arg1 : !tt.ptr<i32>) outs([[EMPTY_PTR128_1]] : tensor<128x!tt.ptr<i32>>) -> tensor<128x!tt.ptr<i32>>
// CHECK: [[ADDPTR128_1:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG1]], [[RANGE128]] : tensor<128x!tt.ptr<i32>>, tensor<128xi32>) outs([[SPLAT_ARG1]] : tensor<128x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[NEW_PTR1:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[NEW_PTR1]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x!tt.ptr<i32>>
// CHECK: tt.store [[ADDPTR128_1]], [[SLICE_LHS]] : tensor<128x!tt.ptr<i32>>
// CHECK: [[EMPTY_PTR128_2:%.+]] = tensor.empty() : tensor<128x!tt.ptr<i32>>
// CHECK: [[SPLAT_ARG2:%.+]] = linalg.fill ins(%arg2 : !tt.ptr<i32>) outs([[EMPTY_PTR128_2]] : tensor<128x!tt.ptr<i32>>) -> tensor<128x!tt.ptr<i32>>
// CHECK: [[ADDPTR128_2:%.+]] = linalg.generic {{.*}} ins([[SPLAT_ARG2]], [[RANGE128]] : tensor<128x!tt.ptr<i32>>, tensor<128xi32>) outs([[SPLAT_ARG2]] : tensor<128x!tt.ptr<i32>>) {
// CHECK: ^bb0([[IN_PTR:%.+]]: !tt.ptr<i32>, [[IN_I32:%.+]]: i32, [[OUT_PTR:%.+]]: !tt.ptr<i32>):
// CHECK: [[NEW_PTR2:%.+]] = tt.addptr [[IN_PTR]], [[IN_I32]] : !tt.ptr<i32>, i32
// CHECK: linalg.yield [[NEW_PTR2]] : !tt.ptr<i32>
// CHECK: } -> tensor<128x!tt.ptr<i32>>
// CHECK: tt.store [[ADDPTR128_2]], [[SLICE_RHS]] : tensor<128x!tt.ptr<i32>>
// CHECK: return
// CHECK: }
// CHECK: }