From 177a624668f58adaa231d26bc5f88f7f157d3b8b Mon Sep 17 00:00:00 2001
From: Nhat Nguyen
Date: Tue, 15 Oct 2024 13:09:57 -0400
Subject: [PATCH] Support tensor of indices as loop iter-arg in sequences of pointer arithmetic (#180)

This PR adds support for tensors of indices that are updated in each loop
iteration while also being used in pointer arithmetic sequences.

## Approach

As with pointer types, the PtrAnalysis pre-pass now eagerly generates
`tts.get_structured_state` ops for tensors of integers. Importantly, we do not
need to know up front whether these values will eventually participate in a
pointer arithmetic sequence: any wrapped value that does not will have its
wrapper removed later in the process. The same approach can easily be extended
to other kinds of values that might appear in pointer arithmetic sequences.

At a high level, `tts.get_structured_state` "wraps" a triton value and returns
two kinds of results. The first result is an SSA value of the same type as the
wrapped value; all users of the original triton value are redirected to use
this first result instead. Because of this, even if the original triton value
ends up not being used in a pointer arithmetic sequence, the IR is trivially
reverted to its original form by deleting the `tts.get_structured_state` op and
forwarding the original triton value to its users again.

The remaining results expose the fields of `PtrState` that are needed to
generate code in loops: the offsets and strides. Within a loop, if a wrapped
triton value is carried as the iter-arg at index `i`, its corresponding offsets
are carried at index `i + 1` and its strides at index `i + 2`. A simplified IR
sketch of this convention is included after the change list below.

## Changes

+ Updated the pre-pass to insert `tts.get_structured_state` ops that wrap
  tensors of indices
+ With the introduction of tensors of indices in loops, we now have to manually
  visit the `tts.get_structured_state` ops to generate the ops that update the
  PtrState. Previously this was unnecessary because triton pointers always have
  a `tt.addptr` at the end of each loop, right before yielding the values,
  which triggers the generation of the state-update ops
+ Improved the logic for determining whether a loop iter-arg needs its PtrState
  updated. We do a BFS-like scan starting from the return values of
  `tts.get_structured_state` ops to determine whether an iter-arg originates
  from a value that may need its PtrState populated
+ Preliminary support for mask sequences being updated in a loop; this is a bit
  of a hack and will need a more robust implementation if these use cases
  appear more frequently
+ Add tests for various scenarios
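For illustration, here is a rough sketch of the IR after the pre-pass and the
PtrAnalysis rewrite, before canonicalization cleans up whatever is unused. This
is not verbatim compiler output: attributes and constant definitions are elided,
and the SSA names (`%structured`, `%indices`, `%cst_4`, `%c4_index`, etc.) are
illustrative only. The new tests under `test/Conversion/TritonToStructured`
show the exact output.

```mlir
// Pre-pass: wrap the tensor of indices. The first result has the original
// type and replaces all uses of %0; the extra results are placeholders for
// the offset and stride tracked in PtrState.
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%structured, %offsets, %strides = "tts.get_structured_state"(%0)
    : (tensor<4xi32>) -> (tensor<4xi32>, index, index)

// The loop carries all three values: if the wrapped value is iter-arg i,
// its offset is iter-arg i + 1 and its stride is iter-arg i + 2.
%res:3 = scf.for %i = %c0_i32 to %c4_i32 step %c1_i32
    iter_args(%indices = %structured, %off = %offsets, %stride = %strides)
    -> (tensor<4xi32>, index, index) : i32 {
  // ... tt.addptr / tt.load / tt.store that use %indices ...
  %next_indices = arith.addi %indices, %cst_4 : tensor<4xi32>
  // PtrAnalysis emits the matching scalar offset update for the next iteration.
  %next_off = arith.addi %off, %c4_index : index
  scf.yield %next_indices, %next_off, %stride : tensor<4xi32>, index, index
}
```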
---
 include/triton-shared/Analysis/MaskAnalysis.h |   4 +
 .../AnalysisStructured/PtrAnalysis.h          |   9 +-
 .../Conversion/TritonToStructured/Passes.td   |   4 +-
 .../IR/TritonStructuredDialect.td             |   6 +-
 lib/Analysis/MaskAnalysis.cpp                 |  72 +++++++-
 lib/AnalysisStructured/PtrAnalysis.cpp        | 170 +++++++++++++++---
 .../TritonToStructuredPass.cpp                |  58 ++++--
 .../IR/TritonStructuredOps.cpp                |  34 ++--
 python/examples/test_tensor_index_iterargs.py | 149 +++++++++++++++
 ...ensor_indices_loop_iterarg_with_masks.mlir |  79 ++++++++
 .../tensor_indices_loop_iterargs_nested.mlir  | 126 +++++++++++++
 ...ensor_indices_loop_iterarg_with_masks.mlir |  51 ++++++
 .../tensor_indices_loop_iterargs_nested.mlir  |  77 ++++++++
 ...oop_iterargs_not_used_ptranalysis_e2e.mlir |  43 +++++
 ...iterargs_not_used_ptranalysis_prepass.mlir |  44 +++++
 .../wraparound_unsupported_add_offset.mlir    |  45 +++--
 16 files changed, 880 insertions(+), 91 deletions(-)
 create mode 100644 python/examples/test_tensor_index_iterargs.py
 create mode 100644 test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir
 create mode 100644 test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir
 create mode 100644 test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir
 create mode 100644 test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir
 create mode 100644 test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir
 create mode 100644 test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir

diff --git a/include/triton-shared/Analysis/MaskAnalysis.h b/include/triton-shared/Analysis/MaskAnalysis.h
index 3f32f4c8..6d67112f 100644
--- a/include/triton-shared/Analysis/MaskAnalysis.h
+++ b/include/triton-shared/Analysis/MaskAnalysis.h
@@ -14,6 +14,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "llvm/Support/LogicalResult.h"
 #include
@@ -137,6 +138,9 @@ struct MaskState {
   // dimension that contains the range.
LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, const Location loc, OpBuilder &builder); + + LogicalResult parseLoopIterArg(Value v, const Location loc, + OpBuilder &builder); }; } // namespace triton diff --git a/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/include/triton-shared/AnalysisStructured/PtrAnalysis.h index 1185184a..d104cd77 100644 --- a/include/triton-shared/AnalysisStructured/PtrAnalysis.h +++ b/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -12,6 +12,8 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -90,11 +92,10 @@ class PtrAnalysis { scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state, llvm::function_ref getReplacementVal); -public: - using IndexMapSet = std::map>; + DenseSet maybeStructuredArgs; - IndexMapSet levelToBlockArgIndex; - int level = 0; +public: + void initializeMaybeStructuredArgs(Operation *op); llvm::SmallDenseMap knownPtrs; diff --git a/include/triton-shared/Conversion/TritonToStructured/Passes.td b/include/triton-shared/Conversion/TritonToStructured/Passes.td index 65798de5..89488d64 100644 --- a/include/triton-shared/Conversion/TritonToStructured/Passes.td +++ b/include/triton-shared/Conversion/TritonToStructured/Passes.td @@ -8,7 +8,9 @@ def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> { let constructor = "triton::createTritonToStructuredPass()"; let options = [ Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false", - "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for"> + "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">, + Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false", + "Skip the prepass"> ]; } diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td index 743648ce..c0f89bfc 100644 --- a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -126,11 +126,11 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe let summary = "Placeholder for the structured pointer states computed during PtrAnalysis."; let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites."; - let arguments = (ins TT_PtrLike:$ptr); - let results = (outs TT_PtrLike:$structuredPtr, Variadic:$offsets, Variadic:$strides); + let arguments = (ins AnyTypeOf<[TT_PtrLike, I32Tensor]>:$input); + let results = (outs AnyTypeOf<[TT_PtrLike, I32Tensor]>:$structured, Variadic:$offsets, Variadic:$strides); let builders = [ - OpBuilder<(ins "Value":$ptr)>, + OpBuilder<(ins "Value":$input)>, ]; let extraClassDeclaration = [{ diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 4de7ae5e..5dd4bc74 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -8,12 +8,22 @@ #include "triton-shared/Analysis/MaskAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LogicalResult.h" + 
#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include + namespace mlir { namespace triton { @@ -38,6 +48,8 @@ LogicalResult MaskState::parse(Value operand, const Location loc, return this->parseSplat(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { return this->parseExpandDims(op, loc, builder); + } else if (!operand.getDefiningOp()) { + return this->parseLoopIterArg(operand, loc, builder); } else if (auto op = operand.getDefiningOp()) { return this->parseExtSI(op, loc, builder); } else { @@ -109,8 +121,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, // + + | // +++++++++++++++++------- // -// If we simply take the subview of `buffer_tmp`, this requires an extra buffer -// to just hold the temporary result. +// If we simply take the subview of `buffer_tmp`, this requires an extra +// buffer to just hold the temporary result. // // So we can subview into block1 and block2 directly. There are 2 cases: // + subview only spans block1 @@ -131,8 +143,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, // Let (row, col1) and (row, col2) be the dimensions of block1 and block2, // respectively. // -// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be the -// dimensions of the full subview, sv1, and sv2, respectively. +// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be +// the dimensions of the full subview, sv1, and sv2, respectively. // // + colView1 = min(colFull, col1) // + colView2 = colFull - colView1 @@ -342,6 +354,58 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, return success(); } +LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc, + OpBuilder &builder) { + assert(!v.getDefiningOp()); + + auto forOp = llvm::dyn_cast(v.getParentRegion()->getParentOp()); + + if (!forOp) { + return failure(); + } + + // TODO: This implementation does not work with nested loops + if (forOp->getParentOfType()) { + return failure(); + } + + auto it = llvm::find(forOp.getRegionIterArgs(), v); + if (it == forOp.getRegionIterArgs().end()) { + return failure(); + } + + auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it); + auto initArg = forOp.getInitArgs()[argIndex]; + if (auto getStateOp = initArg.getDefiningOp()) { + auto tritonValue = getStateOp->getOperand(0); + MaskState lhsState; + if (failed(lhsState.parse(tritonValue, loc, builder))) { + return failure(); + } + + // This is a bit of a hack!! + // + // The offsets and dimensions of a MaskState can now depend on a loop's + // iter-arg. + // + // Because the PtrAnalysis's pre-pass already sets up the offsets, + // we can create a new MaskState for each loop iteration by adding the + // original MaskState with the current iter-arg, which is at `argIndex + + // 1`. + // + // This will not work for nested loop scenarios, which would need a + // more robust implementation. 
+ if (failed(this->addStateScalar( + lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc, builder))) { + return failure(); + } + + return success(); + } + + return failure(); +} + LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, OpBuilder &builder) { diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index 967e3578..e3a2deea 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -11,8 +11,10 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "triton-shared/Analysis/MaskAnalysis.h" #include "triton-shared/Analysis/OpFoldResultUtils.h" @@ -28,9 +30,12 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include #include #include +#include +#include #include #define DEBUG_TYPE "triton-ptr-analysis" @@ -566,6 +571,7 @@ LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState ptrState; if (visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), builder) .failed()) { + // assert(0); return failure(); } @@ -760,10 +766,13 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, } else if (auto op = operand.getDefiningOp()) { return visitOperandForOp(op, operand, state, loc, builder); } else if (!operand.getDefiningOp()) { + if (!knownPtrs.contains(operand)) { + return failure(); + } + // This operand must be an iter-arg of an inner-loop in a multiple-level // nested loop, which means its PtrState must have already been populated // during rewriteForOp of the parent loop. - assert(knownPtrs.contains(operand)); state = knownPtrs[operand]; return success(); } else { @@ -952,7 +961,6 @@ FailureOr PtrAnalysis::getLoopIterArgPtrState(scf::ForOp forOp, FailureOr PtrAnalysis::getLoopResultPtrState(scf::ForOp forOp, size_t index) { - auto state = getLoopInitArgPtrState(forOp, index); if (failed(state)) { return failure(); @@ -964,22 +972,25 @@ FailureOr PtrAnalysis::getLoopResultPtrState(scf::ForOp forOp, } LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { - for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { - if (!isPointerType(arg.getType())) { + for (auto [i, arg] : llvm::enumerate(op.getRegionIterArgs())) { + if (!maybeStructuredArgs.contains(arg)) { continue; } auto state = getLoopIterArgPtrState(op, i); if (failed(state)) { - op.emitError( + // Because the maybeStructuredArgs may contain values that are not + // considered structured by PtrAnalysis, failing to retrieve the PtrState + // should not fail the rewrite process. + // We emit an error for diagnostics and debugging purposes. + op->emitWarning( "Rewrite for-op failed. Could not find PtrState for iter-arg index " + std::to_string(i)); - return failure(); + continue; } // Save the current init arg's PtrState - auto key = op.getRegionIterArgs()[i]; - knownPtrs[key] = state.value(); + knownPtrs[arg] = state.value(); // For tensors of pointers, create a tts.make_tptr at the beginning of the // loop body that correspond to this region iter arg. In case it is used @@ -1004,10 +1015,14 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { // beginning of the loop body. 
We don't lower tt.load and tt.store on // scalars in this pass; pointer arithmetics can also just use the // original pointer. - if (state->getRank() != 0) { - OpBuilder builder(op.getRegion()); - auto maketptrOp = state->createTTSMakeTensorPtrOp(builder, op.getLoc()); - ptrMap.map(key, maketptrOp.getResult()); + // Note that there can be tensor of indices in iter-arg, so we only create + // the make_tensor_ptr op when the arg is of pointer type. + if (isPointerType(arg.getType())) { + if (state->getRank() != 0) { + OpBuilder builder(op.getRegion()); + auto maketptrOp = state->createTTSMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(arg, maketptrOp.getResult()); + } } } @@ -1023,23 +1038,23 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { LogicalResult PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { - auto tritonPtr = op->getOperand(0); + auto tritonValue = op->getOperand(0); - // If this pointer isn't known, it means PtrAnalysis has failed to analyze - // this pointer. In such cases, simply remap all uses of the - // structured-pointer back to its original pointer. - if (!knownPtrs.contains(tritonPtr)) { + // If this triton value isn't known, it means PtrAnalysis has failed to + // analyze this pointer. In such cases, simply remap all uses of the + // structured value back to its original triton value. + if (!knownPtrs.contains(tritonValue)) { op.emitRemark( "Rewrite GetStructuredStateOp failed. Could not find PtrState."); - op.getResult(0).replaceAllUsesWith(tritonPtr); + op.getResult(0).replaceAllUsesWith(tritonValue); return failure(); } - tts::PtrState state = knownPtrs[tritonPtr]; - assert(ptrMap.contains(tritonPtr)); - Value remappedPtr = ptrMap.lookup(tritonPtr); + tts::PtrState state = knownPtrs[tritonValue]; + Value remappedValue = + ptrMap.contains(tritonValue) ? ptrMap.lookup(tritonValue) : tritonValue; - SmallVector replacements{remappedPtr}; + SmallVector replacements{remappedValue}; if (state.getRank() == 0) { // For scalar pointers, the scalar contains the offset and is the only @@ -1131,6 +1146,86 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { return success(); } +// Structured values from the TritonStructuredDialect have offsets and strides +// that might change in each loop iteration and hence will appear in an scf.for +// iter-args like so: +// +// %structured, %offsets, %strides = tts.get_structured_state +// scf.for (%arg0 = %structured, %arg1 = %offsets, %arg2 = %strides) { +// %a = %arg0 + 1 +// %b = %b + 2 +// scf.for (%arg1 = %b) { +// ... +// } +// } +// +// In `rewriteForOp`, we have to recognize such structured values in order to +// rewrite their PtrState accordingly. Previously, only values of Pointer-like +// type (e.g.: tensor> or tt.ptr>), so detecting these values +// is as easy as checking the type. +// +// Now, tensor of indices could also appear in a loop's iter-arg. To reliably +// detect all such cases, we perform a BFS-like traversal of the IR where the +// sources are the results of `tts.get_structured_state`. All values that +// originate from the results of `tts.get_structured_state` are consider +// "maybeStructured". If a loop's iter-arg is considered "maybeStructured", we +// must set up their PtrState during `rewriteForOp`. 
+void PtrAnalysis::initializeMaybeStructuredArgs(Operation *op) { + std::queue q; + DenseSet visited; + + op->walk([&q, &visited](tts::GetStructuredStateOp getStateOp) { + Value value = getStateOp->getResult(0); + visited.insert(value); + q.push(value); + }); + + while (!q.empty()) { + auto v = q.front(); + q.pop(); + for (auto user : v.getUsers()) { + // scf.for is a special case. We have 2 set of values to consider: + // - iter-args + // - loop results + // for every init arg that originates from a `tts.get_structured_state` + // op, its corresponding iter-arg and loop result will also be considered + // "maybeStructured". + if (auto forOp = dyn_cast(user)) { + auto it = llvm::find(forOp.getInitArgs(), v); + + if (it == forOp.getInitArgs().end()) { + continue; + } + + auto argIndex = std::distance(forOp.getInitArgs().begin(), it); + auto iterArg = forOp.getRegionIterArg(argIndex); + auto tiedLoopRes = forOp.getTiedLoopResult(iterArg); + + SmallVector neighbors{iterArg, tiedLoopRes}; + for (auto neighbor : neighbors) { + maybeStructuredArgs.insert(neighbor); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + q.push(neighbor); + } + } + + } else { + for (auto res : user->getResults()) { + if (res.getType() != v.getType()) { + continue; + } + maybeStructuredArgs.insert(res); + if (!visited.contains(res)) { + visited.insert(res); + q.push(res); + } + } + } + } + } +} + LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) { auto ptr = ptrMap.lookupOrNull(op.getPtr()); auto val = op.getValue(); @@ -1222,13 +1317,40 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) { .Case([&](auto forOp) { // `rewriteForOp` recursively visits its children, so regardless // whether the rewrite succeeds or not, we need to return "skip" so - // that the the walk does not visit the for-op's child operations the - // second time. + // that the the walk does not visit the for-op's child operations + // the second time. if (rewriteForOp(forOp).failed()) { forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp"); } return WalkResult::skip(); }) + .Case( + [&](tts::GetStructuredStateOp getStateOp) { + // For tensor of indices potentially being used in pointer + // arithmetic sequence, we need to manually populate the state of + // none already exists. + // This process is necessary because unlike triton pointers in a + // loop which always have a `tt.addptr` that triggers the rewrite + // process which includes generating the ops for updating offsets + // and strides, tensor of indices only have a simple `arith.addi` + // (or other arith ops). + // Without visiting these ops manually, the ops to update the + // offsets and strides would not be generated. 
+ auto tritonValue = getStateOp->getOperand(0); + if (!knownPtrs.contains(tritonValue)) { + PtrState state; + OpBuilder b(getStateOp); + if (succeeded(visitOperand(tritonValue, state, + getStateOp->getLoc(), b))) { + knownPtrs[tritonValue] = state; + } else { + getStateOp->emitRemark("PtrAnalysis: Failed to populate ptr " + "state for tensor of indices"); + } + } + + return WalkResult::skip(); + }) .Default([&](auto) { return WalkResult::advance(); }); }); diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index 1f55f285..45ff8ab2 100644 --- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include #include @@ -84,7 +85,12 @@ class TritonToStructuredPass converter.addConversion([context](RankedTensorType tensorType, SmallVectorImpl &types) -> std::optional { - if (!isa(tensorType.getElementType())) { + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. + if (!isa(tensorType.getElementType()) && + !tensorType.getElementType().isIntOrIndex()) { // There's a subtle difference between returning failure() and // std::nullopt. From the documentation: // @@ -235,8 +241,15 @@ class TritonToStructuredPass // init args easier, especially with multiple levels of loops. // // Background: - // If a triton pointer is updated and returned in a scf.for op, it means + // + // PtrAnalysis computes a PtrState for every operand (or triton value) + // involved in a sequence of pointer arithmetic; some examples include: triton + // pointer, offsets (which could be a tensor of indices or just a simple index + // value). + // + // If a triton value is updated and returned in a scf.for op, it means // that we have to carry its offsets and strides in the scf.for's iterargs. + // // Previously, we have to manually rewrite the loops to include the // relevant information from a PtrState which was rather involved and // error-prone; this was also hard to scale up to multiple level of loops @@ -244,39 +257,46 @@ class TritonToStructuredPass // maintain. // // With the introduction of the prepass that inserts - // `tts.get_structured_state`, the return values of these ops, which include a - // triton pointer and its corresponding offsets and strides, will be used as - // "placeholders" into the scf.for's init-args. We leverage standard MLIR - // infrastructure 1->N conversion to perform this rewrite, which helps - // simplify the logic significantly. + // `tts.get_structured_state`. The return values of these ops, which include a + // triton value with its original result type and its corresponding offsets + // and strides, will be used as "placeholders" into the scf.for's init-args. + // We leverage standard MLIR infrastructure 1->N conversion to perform this + // rewrite, which helps simplify the logic significantly. // // After PtrAnalysis finishes, the return values of these // `tts.get_structured_state` ops will be remapped to the correct - // initialization of the pointer's offsets and strides through the pointer's + // initialization of the value's offsets and strides through the value's // computed PtrState. 
// // Implementation details: // In essence, what we really want to do in the prepass is, for every value - // of triton-pointer-like type (tt.ptr or tensor>), we want to - // create an op `tts.get_structured_state` that takes in the original triton - // pointer value and returns a series of values: + // of triton-pointer-like type (tt.ptr or tensor>) and tensor of + // indices (tensor) which might be used in a sequence of pointer + // arithmetic, we want to create an op `tts.get_structured_state` that takes + // in the original triton value and returns a series of values: // - // {triton_ptr, offset_0, offset_1, ..., stride_0, stride_1,...} + // {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...} // // Applying the above conversion will also mean that any structural ops such // as scf.for and scf.yield that originally takes the triton pointer will - // then take {triton_ptr, offset_0, offset_1, ..., stride_0, stride_1,...}. + // then take {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...}. // // The 1->N type conversion is a perfect fit for this transformation. // Unfortunately, we cannot do this is one pass, because the current 1->N // type conversion implementation for scf.for ops doesn't provide us with a - // way to detect that a type conversion is recursive. So a triton_ptr type - // that gets converted to a {triton_ptr, offset_0, offset_1, ..., stride_0, + // way to detect that a type conversion is recursive. So a triton_value type + // that gets converted to a {triton_value, offset_0, offset_1, ..., stride_0, // stride_1,...} will recursively trigger other conversions. // - // To fix this issue, we have to first convert triton_ptr to - // tuple. + // To fix this issue, we have to first convert triton_value to + // tuple. // Finally, we decompose these tuples into the desired sequence. + // + // Note that even though the type conversion happens for every integer tensor + // appearing in loops' iter-args, this conversion is reversible. If the + // integer tensor isn't used in a pointer arithmetic sequence, + // canonicalization will remove all the `tts.get_structured_state` ops and + // revert the IR back to its original form. 
LogicalResult runTritonToStructuredPrepass() { if (failed(convertToPointerTupleWithOffsetsAndStrides())) { return failure(); @@ -286,7 +306,7 @@ class TritonToStructuredPass } void runOnOperation() override { - if (failed(runTritonToStructuredPrepass())) { + if (!skipPrepass && failed(runTritonToStructuredPrepass())) { signalPassFailure(); return; } @@ -297,6 +317,8 @@ class TritonToStructuredPass auto moduleOp = getOperation(); mlir::tts::PtrAnalysis ptrAnalysis; + ptrAnalysis.initializeMaybeStructuredArgs(moduleOp); + if (failed(ptrAnalysis.rewriteOp(moduleOp))) { moduleOp->emitWarning("PtrAnalysis failed"); } diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp index 76e51661..cf55d834 100644 --- a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp +++ b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include #include #include @@ -94,9 +95,8 @@ void StoreOp::build(OpBuilder &b, OperationState &state, Value ptr, Value value, } LogicalResult GetStructuredStateOp::verify() { - return success(); auto expectedOffsetAndStrideTypes = - getOffsetAndStrideTypes(getContext(), getStructuredPtr().getType()); + getOffsetAndStrideTypes(getContext(), getInput().getType()); if (!expectedOffsetAndStrideTypes.has_value()) { return failure(); @@ -112,8 +112,8 @@ LogicalResult GetStructuredStateOp::verify() { } void GetStructuredStateOp::build(OpBuilder &b, OperationState &state, - Value ptr) { - auto type = ptr.getType(); + Value val) { + auto type = val.getType(); // Builder cannot fail, so we default to empty offset and stride types. // The invalid op will be rejected by the verifier later. @@ -121,13 +121,12 @@ void GetStructuredStateOp::build(OpBuilder &b, OperationState &state, getOffsetAndStrideTypes(b.getContext(), type) .value_or(std::make_pair(SmallVector{}, SmallVector{})); - build(b, state, ptr.getType(), offsetTypes, strideTypes, ptr); + build(b, state, val.getType(), offsetTypes, strideTypes, val); } std::optional, SmallVector>> -GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, - Type ptrLikeType) { - auto sizes = getOffsetAndStrideSegmentSizes(ptrLikeType); +GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, Type type) { + auto sizes = getOffsetAndStrideSegmentSizes(type); if (!sizes.has_value()) { return std::nullopt; } @@ -137,21 +136,28 @@ GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, } std::optional> -GetStructuredStateOp::getOffsetAndStrideSegmentSizes(Type ptrLikeType) { +GetStructuredStateOp::getOffsetAndStrideSegmentSizes(Type type) { int32_t offsetSegmentSize = 0; int32_t strideSegmentSize = 0; - // Unstructured pointers (tensor>) - if (auto tensorType = llvm::dyn_cast(ptrLikeType)) { - if (auto ptrType = - dyn_cast(tensorType.getElementType())) { + if (auto tensorType = llvm::dyn_cast(type)) { + if (tensorType.getElementType().isIntOrIndex()) { + // Tensors of offsets + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. 
+ offsetSegmentSize = strideSegmentSize = tensorType.getRank(); + } else if (auto ptrType = + dyn_cast(tensorType.getElementType())) { + // Unstructured pointers (tensor>) // Each tensor of rank k gets k values for its offsets and k values for // its strides, all of which has Index type. offsetSegmentSize = strideSegmentSize = tensorType.getRank(); } } // Block pointers (!tt.ptr> or !tt.ptr) - else if (auto ptrType = llvm::dyn_cast(ptrLikeType)) { + else if (auto ptrType = llvm::dyn_cast(type)) { if (auto tensorType = llvm::dyn_cast(ptrType.getPointeeType())) { // Each tensor of rank k gets k values for its offsets and k values for diff --git a/python/examples/test_tensor_index_iterargs.py b/python/examples/test_tensor_index_iterargs.py new file mode 100644 index 00000000..6bc52d66 --- /dev/null +++ b/python/examples/test_tensor_index_iterargs.py @@ -0,0 +1,149 @@ +import torch + +import triton +import triton.language as tl + +from triton.backends.triton_shared.driver import CPUDriver + +def test_tensor_indices_nested_with_mask(device): + @triton.jit + def addptr_with_masks(in0, out0, mask_bound): + offs = tl.arange(0, 4) + out_offs = tl.arange(0, 4) + # We're loading 16 elements here, the bound is set to 14 so that + # the mask only applies to the last iteration's load + # TODO: The current mask implementation in triton-shared does not seem + # to work when the mask applies to the entire tensor load, perhaps + # the lowerings for subviews with 0-dimensions do not work? + for i in range(0, 4): + mask = offs < mask_bound + a = tl.load(in0 + offs, mask=mask, other=-11) + tl.store(out0 + out_offs, a) + offs += 4 + out_offs += 4 + + + SIZE = 17 + input = torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + addptr_with_masks[grid](input, output, 14) + expected_output = torch.tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + -11, -11, -1], dtype=torch.int32, device=device) + torch.testing.assert_close(output, expected_output) + print(input) + print(output) + + +def test_tensor_indices_nested(device): + @triton.jit + def tensor_indices_nested(in0, out0): + offs = tl.arange(0, 4) + out_offs = tl.arange(0, 4) + for i in range(0, 2): + offs += i * 2 + a = tl.load(in0 + offs) + tl.store(out0 + out_offs, a) + offs += 4 + out_offs += 4 + for j in range(0, 3): + offs += j * 3 + a = tl.load(in0 + offs) + tl.store(out0 + out_offs, a) + offs += 4 + out_offs += 4 + + SIZE = 64 + input = torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + tensor_indices_nested[grid](input, output) + expected_output = torch.tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 21, 22, 23, 24, 27, 28, + 29, 30, 31, 32, 33, 34, 38, 39, 40, 41, 48, 49, 50, 51, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], device=device, + dtype=torch.int32) + torch.testing.assert_close(output, expected_output) + print(input) + print(output) + +def test_integer_tensor(device): + @triton.jit + def test_1(out0): + offs = tl.arange(0, 4) + out_offs = tl.arange(0, 4) + for i in range(0, 2): + tl.store(out0 + out_offs, offs) + out_offs += 4 + offs += 4 + + + SIZE = 8 + input = 
torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + test_1[grid](output) + print(input) + print(output) + torch.testing.assert_close(input, output) + src = triton.compiler.ASTSource( + fn=test_1, + signature="*fp32", + ) + ret = triton.compile( + src, + ) + print(ret.asm["ttir"]) + + + +def disabled_test_mask(device): + # TODO: This fails to compile in StructuredToMemref + @triton.jit + def test_1(in0, out0, batch): + offs = 4 + tl.arange(0, 4) + out_offs = tl.arange(0, 4) + a = tl.load(in0 + offs, mask=offs < 0, other=-1) + tl.store(out0 + out_offs, a) + + # TODO: This segfauls in the CPU backend + # Crashes when the batch value will mask off all of the tensors + @triton.jit + def test_2(in0, out0, batch): + offs = 4 + tl.arange(0, 4) + out_offs = tl.arange(0, 4) + a = tl.load(in0 + offs, mask=offs < 0, other=-1) + tl.store(out0 + out_offs, a) + + + SIZE = 8 + input = torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + test_1[grid](input, output, 0) + print(input) + print(output) diff --git a/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir new file mode 100644 index 00000000..e2134d84 --- /dev/null +++ b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir @@ -0,0 +1,79 @@ +// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s + +// IR from python/examples/test_tensor_index_iterargs.py +module { + tt.func public @addptr_with_masks(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %cst = arith.constant dense<-1.100000e+01> : tensor<4xf32> + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<4> : tensor<4xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg2 : i32 -> tensor<4xi32> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %4:2 = scf.for %arg3 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %5 = arith.cmpi slt, %arg4, %1 : tensor<4xi32> + %6 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6, %5, %cst : tensor<4x!tt.ptr> + %8 = tt.addptr %3, %arg5 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %8, %7 : tensor<4x!tt.ptr> + %9 = arith.addi %arg4, %cst_0 : tensor<4xi32> + %10 = arith.addi %arg5, %cst_0 : tensor<4xi32> + scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @addptr_with_masks +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// 
CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_minus_1_dot_100000_:%.+]] = arith.constant -1.100000e+01 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_4_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32): +// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 +// CHECK: linalg.yield [[VAR_5_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_3_:%.+]]:4 = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[CST_4_]] step [[CST_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_2_]], [[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[VAR_2_]], [[VAR_arg13_:%.+]] = [[CST_0_1_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 { +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg11_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK-DAG: [[VAR_4_1_:%.+]] = arith.addi [[VAR_arg11_]], [[CST_4_1_]] : index +// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_4_1_]], [[VAR_5_1_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[VAR_arg11_]] : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: [[VAR_8_:%.+]] = arith.cmpi slt, [[VAR_7_]], [[CST_4_1_]] : index +// CHECK: scf.if [[VAR_8_]] { +// CHECK: linalg.fill ins([[CST_minus_1_dot_100000_]] : f32) outs([[RES_]] : memref<4xf32>) +// CHECK: } +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_7_]]{{.}} [1] : memref<4xf32, strided<[?], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_7_]]{{.}} [1] : memref<4xf32> to memref> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> +// CHECK-DAG: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> () +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg10_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg10_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_1_:%.+]]: i32, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32): +// CHECK: [[VAR_13_:%.+]] = arith.addi [[IN_1_]], [[IN_2_]] : i32 +// CHECK: linalg.yield [[VAR_13_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg12_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg12_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: i32): +// CHECK: [[VAR_13_1_:%.+]] = arith.addi [[IN_4_]], [[IN_5_]] : i32 +// 
CHECK: linalg.yield [[VAR_13_1_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_arg13_]], [[CST_4_1_]] : index +// CHECK: scf.yield [[VAR_10_]], [[VAR_4_1_]], [[VAR_11_]], [[VAR_12_]] : tensor<4xi32>, index, tensor<4xi32>, index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir new file mode 100644 index 00000000..60b6e5b1 --- /dev/null +++ b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir @@ -0,0 +1,126 @@ +// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s + +// IR from python/examples/test_tensor_index_iterargs.py +module { + tt.func public @tensor_indices_nested(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c3_i32 = arith.constant 3 : i32 + %cst = arith.constant dense<4> : tensor<4xi32> + %c2_i32 = arith.constant 2 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %4 = arith.muli %arg2, %c2_i32 : i32 + %5 = tt.splat %4 : i32 -> tensor<4xi32> + %6 = arith.addi %arg3, %5 : tensor<4xi32> + %7 = tt.addptr %1, %6 : tensor<4x!tt.ptr>, tensor<4xi32> + %8 = tt.load %7 : tensor<4x!tt.ptr> + %9 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %9, %8 : tensor<4x!tt.ptr> + %10 = arith.addi %6, %cst : tensor<4xi32> + %11 = arith.addi %arg4, %cst : tensor<4xi32> + %12:2 = scf.for %arg5 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg6 = %10, %arg7 = %11) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %13 = arith.muli %arg5, %c3_i32 : i32 + %14 = tt.splat %13 : i32 -> tensor<4xi32> + %15 = arith.addi %arg6, %14 : tensor<4xi32> + %16 = tt.addptr %1, %15 : tensor<4x!tt.ptr>, tensor<4xi32> + %17 = tt.load %16 : tensor<4x!tt.ptr> + %18 = tt.addptr %2, %arg7 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %18, %17 : tensor<4x!tt.ptr> + %19 = arith.addi %15, %cst : tensor<4xi32> + %20 = arith.addi %arg7, %cst : tensor<4xi32> + scf.yield %19, %20 : tensor<4xi32>, tensor<4xi32> + } + scf.yield %12#0, %12#1 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @tensor_indices_nested +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_4_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = 
[#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32): +// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 +// CHECK: linalg.yield [[VAR_5_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_3_:%.+]]:4 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_2_]], [[VAR_arg10_:%.+]] = [[CST_0_1_]], [[VAR_arg11_:%.+]] = [[VAR_2_]], [[VAR_arg12_:%.+]] = [[CST_0_1_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 { +// CHECK-DAG: [[VAR_4_1_:%.+]] = arith.muli [[VAR_arg8_]], [[CST_2_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.index_cast [[VAR_4_1_]] : i32 to index +// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[VAR_4_1_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg9_]], [[VAR_6_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg9_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_1_:%.+]]: i32, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32): +// CHECK: [[VAR_15_:%.+]] = arith.addi [[IN_1_]], [[IN_2_]] : i32 +// CHECK: linalg.yield [[VAR_15_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_5_1_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> () +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_7_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: i32): +// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[IN_4_]], [[IN_5_]] : i32 +// CHECK: linalg.yield [[VAR_15_1_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg11_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg11_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_7_:%.+]]: i32, [[IN_8_:%.+]]: i32, [[IN_9_:%.+]]: i32): +// CHECK: [[VAR_15_2_:%.+]] = arith.addi [[IN_7_]], [[IN_8_]] : i32 +// CHECK: linalg.yield [[VAR_15_2_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_8_]], [[CST_4_1_]] : index +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_arg12_]], [[CST_4_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]]:4 = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_3_]] step [[CST_1_]] iter_args([[VAR_arg14_:%.+]] = [[VAR_10_]], [[VAR_arg15_:%.+]] = 
[[VAR_12_]], [[VAR_arg16_:%.+]] = [[VAR_11_]], [[VAR_arg17_:%.+]] = [[VAR_13_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 { +// CHECK-DAG: [[VAR_15_3_:%.+]] = arith.muli [[VAR_arg13_]], [[CST_3_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_3_]] : i32 to index +// CHECK-DAG: [[VAR_17_:%.+]] = linalg.fill ins([[VAR_15_3_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK: [[VAR_18_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg14_]], [[VAR_17_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg14_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_10_:%.+]]: i32, [[IN_11_:%.+]]: i32, [[IN_12_:%.+]]: i32): +// CHECK: [[VAR_25_:%.+]] = arith.addi [[IN_10_]], [[IN_11_]] : i32 +// CHECK: linalg.yield [[VAR_25_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_16_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_19_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg17_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_20_]] in writable [[VAR_reinterpret_cast_3_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> () +// CHECK: [[VAR_21_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_18_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_18_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_13_:%.+]]: i32, [[IN_14_:%.+]]: i32, [[IN_15_:%.+]]: i32): +// CHECK: [[VAR_25_1_:%.+]] = arith.addi [[IN_13_]], [[IN_14_]] : i32 +// CHECK: linalg.yield [[VAR_25_1_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_22_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg16_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg16_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_16_:%.+]]: i32, [[IN_17_:%.+]]: i32, [[IN_18_:%.+]]: i32): +// CHECK: [[VAR_25_2_:%.+]] = arith.addi [[IN_16_]], [[IN_17_]] : i32 +// CHECK: linalg.yield [[VAR_25_2_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addi [[VAR_19_]], [[CST_4_1_]] : index +// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_arg17_]], [[CST_4_1_]] : index +// CHECK: scf.yield [[VAR_21_]], [[VAR_23_]], [[VAR_22_]], [[VAR_24_]] : tensor<4xi32>, index, tensor<4xi32>, index +// CHECK: } +// CHECK: scf.yield [[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2, [[VAR_14_]]#3 : tensor<4xi32>, index, tensor<4xi32>, index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir new file mode 100644 index 00000000..3c96c2d3 --- /dev/null +++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir @@ -0,0 +1,51 @@ +// RUN: 
triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s
+
+// IR from python/examples/test_tensor_index_iterargs.py
+module {
+  tt.func public @addptr_with_masks(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32) attributes {noinline = false} {
+    %cst = arith.constant dense<-1.100000e+01> : tensor<4xf32>
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst_0 = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg2 : i32 -> tensor<4xi32>
+    %2 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+    %3 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+    %4:2 = scf.for %arg3 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %5 = arith.cmpi slt, %arg4, %1 : tensor<4xi32>
+      %6 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+      %7 = tt.load %6, %5, %cst : tensor<4x!tt.ptr<f32>>
+      %8 = tt.addptr %3, %arg5 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+      tt.store %8, %7 : tensor<4x!tt.ptr<f32>>
+      %9 = arith.addi %arg4, %cst_0 : tensor<4xi32>
+      %10 = arith.addi %arg5, %cst_0 : tensor<4xi32>
+      scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @addptr_with_masks([[arg0_:.+]]: !tt.ptr<f32>, [[arg1_:.+]]: !tt.ptr<f32>, [[arg2_:.+]]: i32) attributes {noinline = false} {
+// CHECK-DAG: [[CST_minus_1_dot_100000_:%.+]] = arith.constant -1.100000e+01 : f32
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_0_:%.+]]:2 = scf.for [[VAR_arg3_:%.+]] = [[CST_0_]] to [[CST_4_1_]] step [[CST_1_]] iter_args([[VAR_arg4_:%.+]] = [[CST_0_1_]], [[VAR_arg5_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 {
+// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[arg2_]] : i32 to index
+// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_arg4_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = "tts.load"([[VAR_1_]], [[VAR_5_]], [[CST_minus_1_dot_100000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>, index, f32) -> tensor<4xf32>
+// CHECK-DAG: [[VAR_7_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg5_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK: "tts.store"([[VAR_7_]], [[VAR_6_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>, tensor<4xf32>) -> ()
+// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_4_]] : index
+// CHECK: scf.yield [[VAR_2_]], [[VAR_8_]] : index, index
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir
new file mode 100644
index 00000000..d868714c
--- /dev/null
+++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir
@@ -0,0 +1,77 @@
+// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s
+
+// IR from python/examples/test_tensor_index_iterargs.py
+module {
+  tt.func public @tensor_indices_nested(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c3_i32 = arith.constant 3 : i32
+    %cst = arith.constant dense<4> : tensor<4xi32>
+    %c2_i32 = arith.constant 2 : i32
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+    %2 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %4 = arith.muli %arg2, %c2_i32 : i32
+      %5 = tt.splat %4 : i32 -> tensor<4xi32>
+      %6 = arith.addi %arg3, %5 : tensor<4xi32>
+      %7 = tt.addptr %1, %6 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+      %8 = tt.load %7 : tensor<4x!tt.ptr<f32>>
+      %9 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+      tt.store %9, %8 : tensor<4x!tt.ptr<f32>>
+      %10 = arith.addi %6, %cst : tensor<4xi32>
+      %11 = arith.addi %arg4, %cst : tensor<4xi32>
+      %12:2 = scf.for %arg5 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg6 = %10, %arg7 = %11) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+        %13 = arith.muli %arg5, %c3_i32 : i32
+        %14 = tt.splat %13 : i32 -> tensor<4xi32>
+        %15 = arith.addi %arg6, %14 : tensor<4xi32>
+        %16 = tt.addptr %1, %15 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+        %17 = tt.load %16 : tensor<4x!tt.ptr<f32>>
+        %18 = tt.addptr %2, %arg7 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+        tt.store %18, %17 : tensor<4x!tt.ptr<f32>>
+        %19 = arith.addi %15, %cst : tensor<4xi32>
+        %20 = arith.addi %arg7, %cst : tensor<4xi32>
+        scf.yield %19, %20 : tensor<4xi32>, tensor<4xi32>
+      }
+      scf.yield %12#0, %12#1 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @tensor_indices_nested([[arg0_:.+]]: !tt.ptr<f32>, [[arg1_:.+]]: !tt.ptr<f32>) attributes {noinline = false} {
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_0_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[CST_0_1_]], [[VAR_arg4_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 {
+// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_arg2_]], [[CST_2_]] : i32
+// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
+// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg3_]], [[VAR_2_]] : index
+// CHECK: [[VAR_4_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_3_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_5_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>) -> tensor<4xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK: "tts.store"([[VAR_6_]], [[VAR_5_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_3_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_4_]] : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_9_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_3_]] step [[CST_1_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_7_]], [[VAR_arg7_:%.+]] = [[VAR_8_]]) -> (index, index) : i32 {
+// CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[VAR_arg5_]], [[CST_3_]] : i32
+// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : i32 to index
+// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_arg6_]], [[VAR_11_]] : index
+// CHECK: [[VAR_13_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_12_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_14_:%.+]] = "tts.load"([[VAR_13_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>) -> tensor<4xf32>
+// CHECK-DAG: [[VAR_15_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg7_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK: "tts.store"([[VAR_15_]], [[VAR_14_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_12_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_4_]] : index
+// CHECK: scf.yield [[VAR_16_]], [[VAR_17_]] : index, index
+// CHECK: }
+// CHECK: scf.yield [[VAR_9_]]#0, [[VAR_9_]]#1 : index, index
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir
new file mode 100644
index 00000000..26732ed7
--- /dev/null
+++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir
@@ -0,0 +1,43 @@
+// IR obtained from "test_integer_tensor" in python/examples/test_tensor_index_iterargs.py
+
+// RUN: triton-shared-opt --triton-to-structured --cse --canonicalize --remove-dead-values %s | FileCheck %s
+module {
+  tt.func public @test_1(%arg0: !tt.ptr<f32>) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+    %2:2 = scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg2 = %0, %arg3 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %3 = tt.addptr %1, %arg2 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+      %4 = arith.sitofp %arg3 : tensor<4xi32> to tensor<4xf32>
+      tt.store %3, %4 : tensor<4x!tt.ptr<f32>>
+      %5 = arith.addi %arg2, %cst : tensor<4xi32>
+      %6 = arith.addi %arg3, %cst : tensor<4xi32>
+      scf.yield %5, %6 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @test_1([[arg0_:.+]]: !tt.ptr<f32>) attributes {noinline = false} {
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<4xi32>
+// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_1_:%.+]]:2 = scf.for [[VAR_arg1_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg2_:%.+]] = [[CST_0_]], [[VAR_arg3_:%.+]] = [[VAR_0_]]) -> (index, tensor<4xi32>) : i32 {
+// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_]]{{.}}, offsets: {{.}}[[VAR_arg2_]]{{.}}, shape: [0], order: [] : <f32> to tensor<4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_3_:%.+]] = arith.sitofp [[VAR_arg3_]] : tensor<4xi32> to tensor<4xf32>
+// CHECK: "tts.store"([[VAR_2_]], [[VAR_3_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr<f32>>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_arg3_]], [[VAR_cst_]] : tensor<4xi32>
+// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_4_]] : index
+// CHECK: scf.yield [[VAR_5_]], [[VAR_4_]] : index, tensor<4xi32>
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir
new file mode 100644
index 00000000..5219da88
--- /dev/null
+++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir
@@ -0,0 +1,44 @@
+// IR obtained from "test_integer_tensor" in python/examples/test_tensor_index_iterargs.py
+
+// RUN: triton-shared-opt --triton-to-structured="run-prepass-only" %s | FileCheck %s
+module {
+  tt.func public @test_1(%arg0: !tt.ptr<f32>) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+    %2:2 = scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg2 = %0, %arg3 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %3 = tt.addptr %1, %arg2 : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+      %4 = arith.sitofp %arg3 : tensor<4xi32> to tensor<4xf32>
+      tt.store %3, %4 : tensor<4x!tt.ptr<f32>>
+      %5 = arith.addi %arg2, %cst : tensor<4xi32>
+      %6 = arith.addi %arg3, %cst : tensor<4xi32>
+      scf.yield %5, %6 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @test_1([[arg0_:.+]]: !tt.ptr<f32>) attributes {noinline = false} {
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<4xi32>
+// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tt.splat [[arg0_]] : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
+// CHECK: [[structured_:%.+]], [[offsets_:%.+]], [[VAR_strides_:%.+]] = "tts.get_structured_state"([[VAR_0_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK: [[structured_0_:%.+]], [[offsets_1_:%.+]], [[VAR_strides_2_:%.+]] = "tts.get_structured_state"([[VAR_0_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK-DAG: [[VAR_2_:%.+]]:6 = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg2_:%.+]] = [[structured_]], [[VAR_arg3_:%.+]] = [[offsets_]], [[VAR_arg4_:%.+]] = [[VAR_strides_]], [[VAR_arg5_:%.+]] = [[structured_0_]], [[VAR_arg6_:%.+]] = [[offsets_1_]], [[VAR_arg7_:%.+]] = [[VAR_strides_2_]]) -> (tensor<4xi32>, index, index, tensor<4xi32>, index, index) : i32 {
+// CHECK-DAG: [[VAR_3_:%.+]] = tt.addptr [[VAR_1_]], [[VAR_arg2_]] : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.sitofp [[VAR_arg5_]] : tensor<4xi32> to tensor<4xf32>
+// CHECK: tt.store [[VAR_3_]], [[VAR_4_]] : tensor<4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg2_]], [[VAR_cst_]] : tensor<4xi32>
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg5_]], [[VAR_cst_]] : tensor<4xi32>
+// CHECK: [[structured_3_:%.+]], [[offsets_4_:%.+]], [[VAR_strides_5_:%.+]] = "tts.get_structured_state"([[VAR_5_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK: [[structured_6_:%.+]], [[offsets_7_:%.+]], [[VAR_strides_8_:%.+]] = "tts.get_structured_state"([[VAR_6_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK: scf.yield [[structured_3_]], [[offsets_4_]], [[VAR_strides_5_]], [[structured_6_]], [[offsets_7_]], [[VAR_strides_8_]] : tensor<4xi32>, index, index, tensor<4xi32>, index, index
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir b/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir
index d62898fc..f60b4a20 100644
--- a/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir
+++ b/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir
@@ -57,10 +57,11 @@ module {
 }
 }
 
-// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
+// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[arg0_:.+]]: !tt.ptr<f32>, [[arg1_:.+]]: !tt.ptr<f32>, [[arg2_:.+]]: i32, [[arg3_:.+]]: i32, [[arg4_:.+]]: i32, [[arg5_:.+]]: i32, [[arg6_:.+]]: i32, [[arg7_:.+]]: i32) {
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-9.900000e+01> : tensor<4x4xf32>
 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
 // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
 // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2> : tensor<4x1xi32>
 // CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<6> : tensor<4xi32>
@@ -69,43 +70,41 @@ module {
 // CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_0_]], [[VAR_cst_2_]] : tensor<4xi32>
-// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[PARAM_3_]] : i32 -> tensor<4xi32>
+// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[arg3_]] : i32 -> tensor<4xi32>
 // CHECK: [[VAR_3_:%.+]] = arith.remsi [[VAR_0_]], [[VAR_2_]] : tensor<4xi32>
 // CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_3_]], [[VAR_cst_1_]] : tensor<4xi32>
 // CHECK-DAG: [[VAR_5_:%.+]] = tt.expand_dims [[VAR_1_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
-// CHECK-DAG: [[VAR_6_:%.+]] = tt.splat [[PARAM_4_]] : i32 -> tensor<4x1xi32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tt.splat [[arg4_]] : i32 -> tensor<4x1xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_5_]], [[VAR_6_]] : tensor<4x1xi32>
 // CHECK-DAG: [[VAR_8_:%.+]] = tt.expand_dims [[VAR_4_]] {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
-// CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[PARAM_5_]] : i32 -> tensor<1x4xi32>
+// CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[arg5_]] : i32 -> tensor<1x4xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[VAR_8_]], [[VAR_9_]] : tensor<1x4xi32>
 // CHECK-DAG: [[VAR_11_:%.+]] = tt.broadcast [[VAR_7_]] : tensor<4x1xi32> -> tensor<4x4xi32>
 // CHECK: [[VAR_12_:%.+]] = tt.broadcast [[VAR_10_]] : tensor<1x4xi32> -> tensor<4x4xi32>
 // CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_11_]], [[VAR_12_]] : tensor<4x4xi32>
-// CHECK-DAG: [[VAR_14_:%.+]] = tt.splat [[PARAM_0_]] : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_14_:%.+]] = tt.splat [[arg0_]] : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_15_:%.+]] = tt.addptr [[VAR_14_]], [[VAR_13_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
 // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
-// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
+// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[arg6_]] : i32 to index
+// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[arg7_]] : i32 to index
+// CHECK: [[VAR_19_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
+// CHECK-DAG: [[VAR_20_:%.+]] = tt.broadcast [[VAR_19_]] : tensor<4x1xi1> -> tensor<4x4xi1>
+// CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[arg4_]], [[CST_4_]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_19_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: [0, 0], shape: [0, 0], order: [] : <f32> to tensor<4x4x!tt.ptr<f32>>
-// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
+// CHECK-DAG: [[VAR_22_:%.+]] = tt.splat [[VAR_21_]] : i32 -> tensor<4x4xi32>
+// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[arg5_]], [[CST_4_]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1>
-// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32>
-// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_25_:%.+]] = tt.splat [[VAR_24_]] : i32 -> tensor<4x4xi32>
-// CHECK-DAG: [[VAR_26_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_19_]]) -> (tensor<4x4x!tt.ptr<f32>>, tensor<4x4x!tt.ptr<f32>>) : i32 {
-// CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_21_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
-// CHECK: tt.store [[VAR_arg10_]], [[LOAD_VAR_arg9_MEM_]] : tensor<4x4x!tt.ptr<f32>>
-// CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_23_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
-// CHECK-DAG: [[VAR_29_:%.+]] = tt.addptr [[VAR_arg10_]], [[VAR_25_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
-// CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4x!tt.ptr<f32>>
+// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_23_]] : i32 to index
+// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (tensor<4x4x!tt.ptr<f32>>, index) : i32 {
+// CHECK-DAG: [[VAR_26_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: {{.}}[[VAR_arg10_]], [[CST_0_]]{{.}}, shape: [0, 0], order: [] : <f32> to tensor<4x4x!tt.ptr<f32>>
+// CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_20_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
+// CHECK: "tts.store"([[VAR_26_]], [[LOAD_VAR_arg9_MEM_]]) <{static_mask_dims = array}> : (tensor<4x4x!tt.ptr<f32>>, tensor<4x4xf32>) -> ()
+// CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_22_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+// CHECK-DAG: [[VAR_29_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_24_]] : index
+// CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, index
 // CHECK: }
 // CHECK: tt.return
 // CHECK: }
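
Note for reviewers: a minimal sketch, in Triton Python, of the kind of kernel the @addptr_with_masks test above is derived from. This is an illustrative approximation, not the verbatim contents of python/examples/test_tensor_index_iterargs.py; the kernel and argument names are assumptions. The point it shows is the pattern this PR targets: a tensor of indices that is both consumed by pointer arithmetic and advanced as a loop iter-arg.

# Illustrative sketch only; names are hypothetical.
import triton
import triton.language as tl

@triton.jit
def addptr_with_masks(in0, out0, n):
    # Tensor of indices used in pointer arithmetic and carried as a loop iter-arg.
    offsets = tl.arange(0, 4)
    out_offsets = tl.arange(0, 4)
    for _ in range(0, 4):
        mask = offsets < n
        vals = tl.load(in0 + offsets, mask=mask, other=-11.0)
        tl.store(out0 + out_offsets, vals)
        # Advancing the index tensors is what becomes the scf.for iter-arg update
        # whose offsets/strides the pre-pass exposes via tts.get_structured_state.
        offsets += 4
        out_offsets += 4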