diff --git a/include/triton-shared/Analysis/MaskAnalysis.h b/include/triton-shared/Analysis/MaskAnalysis.h index 3f32f4c8..6d67112f 100644 --- a/include/triton-shared/Analysis/MaskAnalysis.h +++ b/include/triton-shared/Analysis/MaskAnalysis.h @@ -14,6 +14,7 @@ #include "mlir/Support/LogicalResult.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/LogicalResult.h" #include @@ -137,6 +138,9 @@ struct MaskState { // dimension that contains the range. LogicalResult parseExpandDims(triton::ExpandDimsOp expandDimsOp, const Location loc, OpBuilder &builder); + + LogicalResult parseLoopIterArg(Value v, const Location loc, + OpBuilder &builder); }; } // namespace triton diff --git a/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/include/triton-shared/AnalysisStructured/PtrAnalysis.h index 1185184a..d104cd77 100644 --- a/include/triton-shared/AnalysisStructured/PtrAnalysis.h +++ b/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -12,6 +12,8 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -90,11 +92,10 @@ class PtrAnalysis { scf::ForOp forOp, size_t ptrArgIndex, const PtrState &state, llvm::function_ref getReplacementVal); -public: - using IndexMapSet = std::map>; + DenseSet maybeStructuredArgs; - IndexMapSet levelToBlockArgIndex; - int level = 0; +public: + void initializeMaybeStructuredArgs(Operation *op); llvm::SmallDenseMap knownPtrs; diff --git a/include/triton-shared/Conversion/TritonToStructured/Passes.td b/include/triton-shared/Conversion/TritonToStructured/Passes.td index 65798de5..89488d64 100644 --- a/include/triton-shared/Conversion/TritonToStructured/Passes.td +++ b/include/triton-shared/Conversion/TritonToStructured/Passes.td @@ -8,7 +8,9 @@ def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> { let constructor = "triton::createTritonToStructuredPass()"; let options = [ Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false", - "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for"> + "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">, + Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false", + "Skip the prepass"> ]; } diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td index 743648ce..c0f89bfc 100644 --- a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -126,11 +126,11 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe let summary = "Placeholder for the structured pointer states computed during PtrAnalysis."; let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites."; - let arguments = (ins TT_PtrLike:$ptr); - let results = (outs TT_PtrLike:$structuredPtr, Variadic:$offsets, Variadic:$strides); + let arguments = (ins AnyTypeOf<[TT_PtrLike, I32Tensor]>:$input); + let results = (outs AnyTypeOf<[TT_PtrLike, I32Tensor]>:$structured, Variadic:$offsets, Variadic:$strides); let builders = [ - OpBuilder<(ins "Value":$ptr)>, + OpBuilder<(ins "Value":$input)>, ]; let 
extraClassDeclaration = [{ diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp index 4de7ae5e..5dd4bc74 100644 --- a/lib/Analysis/MaskAnalysis.cpp +++ b/lib/Analysis/MaskAnalysis.cpp @@ -8,12 +8,22 @@ #include "triton-shared/Analysis/MaskAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LogicalResult.h" + #include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include + namespace mlir { namespace triton { @@ -38,6 +48,8 @@ LogicalResult MaskState::parse(Value operand, const Location loc, return this->parseSplat(op, loc, builder); } else if (auto op = operand.getDefiningOp()) { return this->parseExpandDims(op, loc, builder); + } else if (!operand.getDefiningOp()) { + return this->parseLoopIterArg(operand, loc, builder); } else if (auto op = operand.getDefiningOp()) { return this->parseExtSI(op, loc, builder); } else { @@ -109,8 +121,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, // + + | // +++++++++++++++++------- // -// If we simply take the subview of `buffer_tmp`, this requires an extra buffer -// to just hold the temporary result. +// If we simply take the subview of `buffer_tmp`, this requires an extra +// buffer to just hold the temporary result. // // So we can subview into block1 and block2 directly. There are 2 cases: // + subview only spans block1 @@ -131,8 +143,8 @@ static memref::SubViewOp createSubview(Value src, Location loc, OpBuilder &b, // Let (row, col1) and (row, col2) be the dimensions of block1 and block2, // respectively. // -// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be the -// dimensions of the full subview, sv1, and sv2, respectively. +// Let (rowFull, colFull), (rowView1, colView1) and (rowView2, colView2) be +// the dimensions of the full subview, sv1, and sv2, respectively. // // + colView1 = min(colFull, col1) // + colView2 = colFull - colView1 @@ -342,6 +354,58 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc, return success(); } +LogicalResult MaskState::parseLoopIterArg(Value v, const Location loc, + OpBuilder &builder) { + assert(!v.getDefiningOp()); + + auto forOp = llvm::dyn_cast(v.getParentRegion()->getParentOp()); + + if (!forOp) { + return failure(); + } + + // TODO: This implementation does not work with nested loops + if (forOp->getParentOfType()) { + return failure(); + } + + auto it = llvm::find(forOp.getRegionIterArgs(), v); + if (it == forOp.getRegionIterArgs().end()) { + return failure(); + } + + auto argIndex = std::distance(forOp.getRegionIterArgs().begin(), it); + auto initArg = forOp.getInitArgs()[argIndex]; + if (auto getStateOp = initArg.getDefiningOp()) { + auto tritonValue = getStateOp->getOperand(0); + MaskState lhsState; + if (failed(lhsState.parse(tritonValue, loc, builder))) { + return failure(); + } + + // This is a bit of a hack!! + // + // The offsets and dimensions of a MaskState can now depend on a loop's + // iter-arg. 
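+  //
+  // Illustrative sketch (simplified from the new lit tests): after the
+  // pre-pass, a loop that carries a tensor of indices also carries that
+  // tensor's scalar offset right next to it in the iter-args:
+  //
+  //   scf.for ... iter_args(%offs = %structured, %offs_0 = %offset_0, ...)
+  //
+  // where %offs is the iter-arg being parsed here and %offs_0 is the
+  // running scalar offset maintained by PtrAnalysis.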
+ // + // Because the PtrAnalysis's pre-pass already sets up the offsets, + // we can create a new MaskState for each loop iteration by adding the + // original MaskState with the current iter-arg, which is at `argIndex + + // 1`. + // + // This will not work for nested loop scenarios, which would need a + // more robust implementation. + if (failed(this->addStateScalar( + lhsState, forOp.getRegionIterArgs()[argIndex + 1], loc, builder))) { + return failure(); + } + + return success(); + } + + return failure(); +} + LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp, const Location loc, OpBuilder &builder) { diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index 967e3578..e3a2deea 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -11,8 +11,10 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "triton-shared/Analysis/MaskAnalysis.h" #include "triton-shared/Analysis/OpFoldResultUtils.h" @@ -28,9 +30,12 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include #include #include +#include +#include #include #define DEBUG_TYPE "triton-ptr-analysis" @@ -566,6 +571,7 @@ LogicalResult PtrAnalysis::visitOperandAddptr(triton::AddPtrOp addptrOp, PtrState ptrState; if (visitOperand(addptrOp.getPtr(), ptrState, addptrOp.getLoc(), builder) .failed()) { + // assert(0); return failure(); } @@ -760,10 +766,13 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, } else if (auto op = operand.getDefiningOp()) { return visitOperandForOp(op, operand, state, loc, builder); } else if (!operand.getDefiningOp()) { + if (!knownPtrs.contains(operand)) { + return failure(); + } + // This operand must be an iter-arg of an inner-loop in a multiple-level // nested loop, which means its PtrState must have already been populated // during rewriteForOp of the parent loop. - assert(knownPtrs.contains(operand)); state = knownPtrs[operand]; return success(); } else { @@ -952,7 +961,6 @@ FailureOr PtrAnalysis::getLoopIterArgPtrState(scf::ForOp forOp, FailureOr PtrAnalysis::getLoopResultPtrState(scf::ForOp forOp, size_t index) { - auto state = getLoopInitArgPtrState(forOp, index); if (failed(state)) { return failure(); @@ -964,22 +972,25 @@ FailureOr PtrAnalysis::getLoopResultPtrState(scf::ForOp forOp, } LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { - for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { - if (!isPointerType(arg.getType())) { + for (auto [i, arg] : llvm::enumerate(op.getRegionIterArgs())) { + if (!maybeStructuredArgs.contains(arg)) { continue; } auto state = getLoopIterArgPtrState(op, i); if (failed(state)) { - op.emitError( + // Because the maybeStructuredArgs may contain values that are not + // considered structured by PtrAnalysis, failing to retrieve the PtrState + // should not fail the rewrite process. + // We emit an error for diagnostics and debugging purposes. + op->emitWarning( "Rewrite for-op failed. 
Could not find PtrState for iter-arg index " + std::to_string(i)); - return failure(); + continue; } // Save the current init arg's PtrState - auto key = op.getRegionIterArgs()[i]; - knownPtrs[key] = state.value(); + knownPtrs[arg] = state.value(); // For tensors of pointers, create a tts.make_tptr at the beginning of the // loop body that correspond to this region iter arg. In case it is used @@ -1004,10 +1015,14 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { // beginning of the loop body. We don't lower tt.load and tt.store on // scalars in this pass; pointer arithmetics can also just use the // original pointer. - if (state->getRank() != 0) { - OpBuilder builder(op.getRegion()); - auto maketptrOp = state->createTTSMakeTensorPtrOp(builder, op.getLoc()); - ptrMap.map(key, maketptrOp.getResult()); + // Note that there can be tensor of indices in iter-arg, so we only create + // the make_tensor_ptr op when the arg is of pointer type. + if (isPointerType(arg.getType())) { + if (state->getRank() != 0) { + OpBuilder builder(op.getRegion()); + auto maketptrOp = state->createTTSMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(arg, maketptrOp.getResult()); + } } } @@ -1023,23 +1038,23 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { LogicalResult PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { - auto tritonPtr = op->getOperand(0); + auto tritonValue = op->getOperand(0); - // If this pointer isn't known, it means PtrAnalysis has failed to analyze - // this pointer. In such cases, simply remap all uses of the - // structured-pointer back to its original pointer. - if (!knownPtrs.contains(tritonPtr)) { + // If this triton value isn't known, it means PtrAnalysis has failed to + // analyze this pointer. In such cases, simply remap all uses of the + // structured value back to its original triton value. + if (!knownPtrs.contains(tritonValue)) { op.emitRemark( "Rewrite GetStructuredStateOp failed. Could not find PtrState."); - op.getResult(0).replaceAllUsesWith(tritonPtr); + op.getResult(0).replaceAllUsesWith(tritonValue); return failure(); } - tts::PtrState state = knownPtrs[tritonPtr]; - assert(ptrMap.contains(tritonPtr)); - Value remappedPtr = ptrMap.lookup(tritonPtr); + tts::PtrState state = knownPtrs[tritonValue]; + Value remappedValue = + ptrMap.contains(tritonValue) ? ptrMap.lookup(tritonValue) : tritonValue; - SmallVector replacements{remappedPtr}; + SmallVector replacements{remappedValue}; if (state.getRank() == 0) { // For scalar pointers, the scalar contains the offset and is the only @@ -1131,6 +1146,86 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { return success(); } +// Structured values from the TritonStructuredDialect have offsets and strides +// that might change in each loop iteration and hence will appear in an scf.for +// iter-args like so: +// +// %structured, %offsets, %strides = tts.get_structured_state +// scf.for (%arg0 = %structured, %arg1 = %offsets, %arg2 = %strides) { +// %a = %arg0 + 1 +// %b = %b + 2 +// scf.for (%arg1 = %b) { +// ... +// } +// } +// +// In `rewriteForOp`, we have to recognize such structured values in order to +// rewrite their PtrState accordingly. Previously, only values of Pointer-like +// type (e.g.: tensor> or tt.ptr>), so detecting these values +// is as easy as checking the type. +// +// Now, tensor of indices could also appear in a loop's iter-arg. 
To reliably +// detect all such cases, we perform a BFS-like traversal of the IR where the +// sources are the results of `tts.get_structured_state`. All values that +// originate from the results of `tts.get_structured_state` are consider +// "maybeStructured". If a loop's iter-arg is considered "maybeStructured", we +// must set up their PtrState during `rewriteForOp`. +void PtrAnalysis::initializeMaybeStructuredArgs(Operation *op) { + std::queue q; + DenseSet visited; + + op->walk([&q, &visited](tts::GetStructuredStateOp getStateOp) { + Value value = getStateOp->getResult(0); + visited.insert(value); + q.push(value); + }); + + while (!q.empty()) { + auto v = q.front(); + q.pop(); + for (auto user : v.getUsers()) { + // scf.for is a special case. We have 2 set of values to consider: + // - iter-args + // - loop results + // for every init arg that originates from a `tts.get_structured_state` + // op, its corresponding iter-arg and loop result will also be considered + // "maybeStructured". + if (auto forOp = dyn_cast(user)) { + auto it = llvm::find(forOp.getInitArgs(), v); + + if (it == forOp.getInitArgs().end()) { + continue; + } + + auto argIndex = std::distance(forOp.getInitArgs().begin(), it); + auto iterArg = forOp.getRegionIterArg(argIndex); + auto tiedLoopRes = forOp.getTiedLoopResult(iterArg); + + SmallVector neighbors{iterArg, tiedLoopRes}; + for (auto neighbor : neighbors) { + maybeStructuredArgs.insert(neighbor); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + q.push(neighbor); + } + } + + } else { + for (auto res : user->getResults()) { + if (res.getType() != v.getType()) { + continue; + } + maybeStructuredArgs.insert(res); + if (!visited.contains(res)) { + visited.insert(res); + q.push(res); + } + } + } + } + } +} + LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) { auto ptr = ptrMap.lookupOrNull(op.getPtr()); auto val = op.getValue(); @@ -1222,13 +1317,40 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) { .Case([&](auto forOp) { // `rewriteForOp` recursively visits its children, so regardless // whether the rewrite succeeds or not, we need to return "skip" so - // that the the walk does not visit the for-op's child operations the - // second time. + // that the the walk does not visit the for-op's child operations + // the second time. if (rewriteForOp(forOp).failed()) { forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp"); } return WalkResult::skip(); }) + .Case( + [&](tts::GetStructuredStateOp getStateOp) { + // For tensor of indices potentially being used in pointer + // arithmetic sequence, we need to manually populate the state of + // none already exists. + // This process is necessary because unlike triton pointers in a + // loop which always have a `tt.addptr` that triggers the rewrite + // process which includes generating the ops for updating offsets + // and strides, tensor of indices only have a simple `arith.addi` + // (or other arith ops). + // Without visiting these ops manually, the ops to update the + // offsets and strides would not be generated. 
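+          //
+          // For illustration, the IR shape this handles (taken from the
+          // new prepass test) has the operand produced by plain arith ops:
+          //
+          //   %5 = arith.addi %arg2, %cst : tensor<4xi32>
+          //   %s, %o, %st = "tts.get_structured_state"(%5)
+          //       : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+          //
+          // Visiting %5 here records its PtrState so that the offset and
+          // stride updates can be materialized when its results are
+          // remapped.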
+ auto tritonValue = getStateOp->getOperand(0); + if (!knownPtrs.contains(tritonValue)) { + PtrState state; + OpBuilder b(getStateOp); + if (succeeded(visitOperand(tritonValue, state, + getStateOp->getLoc(), b))) { + knownPtrs[tritonValue] = state; + } else { + getStateOp->emitRemark("PtrAnalysis: Failed to populate ptr " + "state for tensor of indices"); + } + } + + return WalkResult::skip(); + }) .Default([&](auto) { return WalkResult::advance(); }); }); diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index 1f55f285..45ff8ab2 100644 --- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include #include @@ -84,7 +85,12 @@ class TritonToStructuredPass converter.addConversion([context](RankedTensorType tensorType, SmallVectorImpl &types) -> std::optional { - if (!isa(tensorType.getElementType())) { + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. + if (!isa(tensorType.getElementType()) && + !tensorType.getElementType().isIntOrIndex()) { // There's a subtle difference between returning failure() and // std::nullopt. From the documentation: // @@ -235,8 +241,15 @@ class TritonToStructuredPass // init args easier, especially with multiple levels of loops. // // Background: - // If a triton pointer is updated and returned in a scf.for op, it means + // + // PtrAnalysis computes a PtrState for every operand (or triton value) + // involved in a sequence of pointer arithmetic; some examples include: triton + // pointer, offsets (which could be a tensor of indices or just a simple index + // value). + // + // If a triton value is updated and returned in a scf.for op, it means // that we have to carry its offsets and strides in the scf.for's iterargs. + // // Previously, we have to manually rewrite the loops to include the // relevant information from a PtrState which was rather involved and // error-prone; this was also hard to scale up to multiple level of loops @@ -244,39 +257,46 @@ class TritonToStructuredPass // maintain. // // With the introduction of the prepass that inserts - // `tts.get_structured_state`, the return values of these ops, which include a - // triton pointer and its corresponding offsets and strides, will be used as - // "placeholders" into the scf.for's init-args. We leverage standard MLIR - // infrastructure 1->N conversion to perform this rewrite, which helps - // simplify the logic significantly. + // `tts.get_structured_state`. The return values of these ops, which include a + // triton value with its original result type and its corresponding offsets + // and strides, will be used as "placeholders" into the scf.for's init-args. + // We leverage standard MLIR infrastructure 1->N conversion to perform this + // rewrite, which helps simplify the logic significantly. // // After PtrAnalysis finishes, the return values of these // `tts.get_structured_state` ops will be remapped to the correct - // initialization of the pointer's offsets and strides through the pointer's + // initialization of the value's offsets and strides through the value's // computed PtrState. 
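+  //
+  // Illustrative example (simplified from the new prepass test): a loop
+  // carrying a tensor of indices becomes
+  //
+  //   %s, %o, %st = "tts.get_structured_state"(%0)
+  //       : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+  //   scf.for ... iter_args(%arg2 = %s, %arg3 = %o, %arg4 = %st, ...)
+  //
+  // and the scf.yield is rewritten to yield the corresponding structured
+  // results of values computed in the loop body.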
// // Implementation details: // In essence, what we really want to do in the prepass is, for every value - // of triton-pointer-like type (tt.ptr or tensor>), we want to - // create an op `tts.get_structured_state` that takes in the original triton - // pointer value and returns a series of values: + // of triton-pointer-like type (tt.ptr or tensor>) and tensor of + // indices (tensor) which might be used in a sequence of pointer + // arithmetic, we want to create an op `tts.get_structured_state` that takes + // in the original triton value and returns a series of values: // - // {triton_ptr, offset_0, offset_1, ..., stride_0, stride_1,...} + // {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...} // // Applying the above conversion will also mean that any structural ops such // as scf.for and scf.yield that originally takes the triton pointer will - // then take {triton_ptr, offset_0, offset_1, ..., stride_0, stride_1,...}. + // then take {triton_value, offset_0, offset_1, ..., stride_0, stride_1,...}. // // The 1->N type conversion is a perfect fit for this transformation. // Unfortunately, we cannot do this is one pass, because the current 1->N // type conversion implementation for scf.for ops doesn't provide us with a - // way to detect that a type conversion is recursive. So a triton_ptr type - // that gets converted to a {triton_ptr, offset_0, offset_1, ..., stride_0, + // way to detect that a type conversion is recursive. So a triton_value type + // that gets converted to a {triton_value, offset_0, offset_1, ..., stride_0, // stride_1,...} will recursively trigger other conversions. // - // To fix this issue, we have to first convert triton_ptr to - // tuple. + // To fix this issue, we have to first convert triton_value to + // tuple. // Finally, we decompose these tuples into the desired sequence. + // + // Note that even though the type conversion happens for every integer tensor + // appearing in loops' iter-args, this conversion is reversible. If the + // integer tensor isn't used in a pointer arithmetic sequence, + // canonicalization will remove all the `tts.get_structured_state` ops and + // revert the IR back to its original form. 
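+  //
+  // For example, in the new "test_integer_tensor" e2e test, one iter-arg is
+  // only consumed by arith.sitofp; after canonicalization it remains a plain
+  // tensor<4xi32> iter-arg, while the iter-arg that feeds tt.addptr ends up
+  // carried as an index offset instead.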
LogicalResult runTritonToStructuredPrepass() { if (failed(convertToPointerTupleWithOffsetsAndStrides())) { return failure(); @@ -286,7 +306,7 @@ class TritonToStructuredPass } void runOnOperation() override { - if (failed(runTritonToStructuredPrepass())) { + if (!skipPrepass && failed(runTritonToStructuredPrepass())) { signalPassFailure(); return; } @@ -297,6 +317,8 @@ class TritonToStructuredPass auto moduleOp = getOperation(); mlir::tts::PtrAnalysis ptrAnalysis; + ptrAnalysis.initializeMaybeStructuredArgs(moduleOp); + if (failed(ptrAnalysis.rewriteOp(moduleOp))) { moduleOp->emitWarning("PtrAnalysis failed"); } diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp index 76e51661..cf55d834 100644 --- a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp +++ b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include #include #include @@ -94,9 +95,8 @@ void StoreOp::build(OpBuilder &b, OperationState &state, Value ptr, Value value, } LogicalResult GetStructuredStateOp::verify() { - return success(); auto expectedOffsetAndStrideTypes = - getOffsetAndStrideTypes(getContext(), getStructuredPtr().getType()); + getOffsetAndStrideTypes(getContext(), getInput().getType()); if (!expectedOffsetAndStrideTypes.has_value()) { return failure(); @@ -112,8 +112,8 @@ LogicalResult GetStructuredStateOp::verify() { } void GetStructuredStateOp::build(OpBuilder &b, OperationState &state, - Value ptr) { - auto type = ptr.getType(); + Value val) { + auto type = val.getType(); // Builder cannot fail, so we default to empty offset and stride types. // The invalid op will be rejected by the verifier later. @@ -121,13 +121,12 @@ void GetStructuredStateOp::build(OpBuilder &b, OperationState &state, getOffsetAndStrideTypes(b.getContext(), type) .value_or(std::make_pair(SmallVector{}, SmallVector{})); - build(b, state, ptr.getType(), offsetTypes, strideTypes, ptr); + build(b, state, val.getType(), offsetTypes, strideTypes, val); } std::optional, SmallVector>> -GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, - Type ptrLikeType) { - auto sizes = getOffsetAndStrideSegmentSizes(ptrLikeType); +GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, Type type) { + auto sizes = getOffsetAndStrideSegmentSizes(type); if (!sizes.has_value()) { return std::nullopt; } @@ -137,21 +136,28 @@ GetStructuredStateOp::getOffsetAndStrideTypes(MLIRContext *context, } std::optional> -GetStructuredStateOp::getOffsetAndStrideSegmentSizes(Type ptrLikeType) { +GetStructuredStateOp::getOffsetAndStrideSegmentSizes(Type type) { int32_t offsetSegmentSize = 0; int32_t strideSegmentSize = 0; - // Unstructured pointers (tensor>) - if (auto tensorType = llvm::dyn_cast(ptrLikeType)) { - if (auto ptrType = - dyn_cast(tensorType.getElementType())) { + if (auto tensorType = llvm::dyn_cast(type)) { + if (tensorType.getElementType().isIntOrIndex()) { + // Tensors of offsets + // Important note: + // We only care about tensor of index / int (in addition to pointer type) + // because only values of int and index type can potentially be part of a + // pointer arithmetic sequence. 
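+      // For example, a tensor<4xi32> of offsets gets one index-typed offset
+      // and one index-typed stride (rank 1), mirroring the tensors of
+      // pointers case below.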
+ offsetSegmentSize = strideSegmentSize = tensorType.getRank(); + } else if (auto ptrType = + dyn_cast(tensorType.getElementType())) { + // Unstructured pointers (tensor>) // Each tensor of rank k gets k values for its offsets and k values for // its strides, all of which has Index type. offsetSegmentSize = strideSegmentSize = tensorType.getRank(); } } // Block pointers (!tt.ptr> or !tt.ptr) - else if (auto ptrType = llvm::dyn_cast(ptrLikeType)) { + else if (auto ptrType = llvm::dyn_cast(type)) { if (auto tensorType = llvm::dyn_cast(ptrType.getPointeeType())) { // Each tensor of rank k gets k values for its offsets and k values for diff --git a/python/examples/test_tensor_index_iterargs.py b/python/examples/test_tensor_index_iterargs.py new file mode 100644 index 00000000..6bc52d66 --- /dev/null +++ b/python/examples/test_tensor_index_iterargs.py @@ -0,0 +1,149 @@ +import torch + +import triton +import triton.language as tl + +from triton.backends.triton_shared.driver import CPUDriver + +def test_tensor_indices_nested_with_mask(device): + @triton.jit + def addptr_with_masks(in0, out0, mask_bound): + offs = tl.arange(0, 4) + out_offs = tl.arange(0, 4) + # We're loading 16 elements here, the bound is set to 14 so that + # the mask only applies to the last iteration's load + # TODO: The current mask implementation in triton-shared does not seem + # to work when the mask applies to the entire tensor load, perhaps + # the lowerings for subviews with 0-dimensions do not work? + for i in range(0, 4): + mask = offs < mask_bound + a = tl.load(in0 + offs, mask=mask, other=-11) + tl.store(out0 + out_offs, a) + offs += 4 + out_offs += 4 + + + SIZE = 17 + input = torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + addptr_with_masks[grid](input, output, 14) + expected_output = torch.tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + -11, -11, -1], dtype=torch.int32, device=device) + torch.testing.assert_close(output, expected_output) + print(input) + print(output) + + +def test_tensor_indices_nested(device): + @triton.jit + def tensor_indices_nested(in0, out0): + offs = tl.arange(0, 4) + out_offs = tl.arange(0, 4) + for i in range(0, 2): + offs += i * 2 + a = tl.load(in0 + offs) + tl.store(out0 + out_offs, a) + offs += 4 + out_offs += 4 + for j in range(0, 3): + offs += j * 3 + a = tl.load(in0 + offs) + tl.store(out0 + out_offs, a) + offs += 4 + out_offs += 4 + + SIZE = 64 + input = torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + tensor_indices_nested[grid](input, output) + expected_output = torch.tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 21, 22, 23, 24, 27, 28, + 29, 30, 31, 32, 33, 34, 38, 39, 40, 41, 48, 49, 50, 51, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], device=device, + dtype=torch.int32) + torch.testing.assert_close(output, expected_output) + print(input) + print(output) + +def test_integer_tensor(device): + @triton.jit + def test_1(out0): + offs = tl.arange(0, 4) + out_offs = tl.arange(0, 4) + for i in range(0, 2): + tl.store(out0 + out_offs, offs) + out_offs += 4 + offs += 4 + + + SIZE = 8 + input = 
torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + test_1[grid](output) + print(input) + print(output) + torch.testing.assert_close(input, output) + src = triton.compiler.ASTSource( + fn=test_1, + signature="*fp32", + ) + ret = triton.compile( + src, + ) + print(ret.asm["ttir"]) + + + +def disabled_test_mask(device): + # TODO: This fails to compile in StructuredToMemref + @triton.jit + def test_1(in0, out0, batch): + offs = 4 + tl.arange(0, 4) + out_offs = tl.arange(0, 4) + a = tl.load(in0 + offs, mask=offs < 0, other=-1) + tl.store(out0 + out_offs, a) + + # TODO: This segfauls in the CPU backend + # Crashes when the batch value will mask off all of the tensors + @triton.jit + def test_2(in0, out0, batch): + offs = 4 + tl.arange(0, 4) + out_offs = tl.arange(0, 4) + a = tl.load(in0 + offs, mask=offs < 0, other=-1) + tl.store(out0 + out_offs, a) + + + SIZE = 8 + input = torch.arange(0, SIZE, device=device, dtype=torch.int32) + output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) + + if device == 'cpu': + triton.runtime.driver.set_active(CPUDriver()) + + grid = lambda meta: (1,) + + print(output) + test_1[grid](input, output, 0) + print(input) + print(output) diff --git a/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir new file mode 100644 index 00000000..e2134d84 --- /dev/null +++ b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir @@ -0,0 +1,79 @@ +// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s + +// IR from python/examples/test_tensor_index_iterargs.py +module { + tt.func public @addptr_with_masks(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %cst = arith.constant dense<-1.100000e+01> : tensor<4xf32> + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<4> : tensor<4xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg2 : i32 -> tensor<4xi32> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %4:2 = scf.for %arg3 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %5 = arith.cmpi slt, %arg4, %1 : tensor<4xi32> + %6 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6, %5, %cst : tensor<4x!tt.ptr> + %8 = tt.addptr %3, %arg5 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %8, %7 : tensor<4x!tt.ptr> + %9 = arith.addi %arg4, %cst_0 : tensor<4xi32> + %10 = arith.addi %arg5, %cst_0 : tensor<4xi32> + scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @addptr_with_masks +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// 
CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_minus_1_dot_100000_:%.+]] = arith.constant -1.100000e+01 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_4_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32): +// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 +// CHECK: linalg.yield [[VAR_5_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_3_:%.+]]:4 = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[CST_4_]] step [[CST_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_2_]], [[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[VAR_2_]], [[VAR_arg13_:%.+]] = [[CST_0_1_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 { +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg11_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK-DAG: [[VAR_4_1_:%.+]] = arith.addi [[VAR_arg11_]], [[CST_4_1_]] : index +// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_4_1_]], [[VAR_5_1_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[VAR_arg11_]] : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: [[VAR_8_:%.+]] = arith.cmpi slt, [[VAR_7_]], [[CST_4_1_]] : index +// CHECK: scf.if [[VAR_8_]] { +// CHECK: linalg.fill ins([[CST_minus_1_dot_100000_]] : f32) outs([[RES_]] : memref<4xf32>) +// CHECK: } +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_7_]]{{.}} [1] : memref<4xf32, strided<[?], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_7_]]{{.}} [1] : memref<4xf32> to memref> +// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> +// CHECK-DAG: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> () +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg10_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg10_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_1_:%.+]]: i32, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32): +// CHECK: [[VAR_13_:%.+]] = arith.addi [[IN_1_]], [[IN_2_]] : i32 +// CHECK: linalg.yield [[VAR_13_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg12_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg12_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: i32): +// CHECK: [[VAR_13_1_:%.+]] = arith.addi [[IN_4_]], [[IN_5_]] : i32 +// 
CHECK: linalg.yield [[VAR_13_1_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_arg13_]], [[CST_4_1_]] : index +// CHECK: scf.yield [[VAR_10_]], [[VAR_4_1_]], [[VAR_11_]], [[VAR_12_]] : tensor<4xi32>, index, tensor<4xi32>, index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir new file mode 100644 index 00000000..60b6e5b1 --- /dev/null +++ b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterargs_nested.mlir @@ -0,0 +1,126 @@ +// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s + +// IR from python/examples/test_tensor_index_iterargs.py +module { + tt.func public @tensor_indices_nested(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c3_i32 = arith.constant 3 : i32 + %cst = arith.constant dense<4> : tensor<4xi32> + %c2_i32 = arith.constant 2 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %4 = arith.muli %arg2, %c2_i32 : i32 + %5 = tt.splat %4 : i32 -> tensor<4xi32> + %6 = arith.addi %arg3, %5 : tensor<4xi32> + %7 = tt.addptr %1, %6 : tensor<4x!tt.ptr>, tensor<4xi32> + %8 = tt.load %7 : tensor<4x!tt.ptr> + %9 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %9, %8 : tensor<4x!tt.ptr> + %10 = arith.addi %6, %cst : tensor<4xi32> + %11 = arith.addi %arg4, %cst : tensor<4xi32> + %12:2 = scf.for %arg5 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg6 = %10, %arg7 = %11) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %13 = arith.muli %arg5, %c3_i32 : i32 + %14 = tt.splat %13 : i32 -> tensor<4xi32> + %15 = arith.addi %arg6, %14 : tensor<4xi32> + %16 = tt.addptr %1, %15 : tensor<4x!tt.ptr>, tensor<4xi32> + %17 = tt.load %16 : tensor<4x!tt.ptr> + %18 = tt.addptr %2, %arg7 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %18, %17 : tensor<4x!tt.ptr> + %19 = arith.addi %15, %cst : tensor<4xi32> + %20 = arith.addi %arg7, %cst : tensor<4xi32> + scf.yield %19, %20 : tensor<4xi32>, tensor<4xi32> + } + scf.yield %12#0, %12#1 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @tensor_indices_nested +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_4_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = 
[#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32): +// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 +// CHECK: linalg.yield [[VAR_5_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_3_:%.+]]:4 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_2_]], [[VAR_arg10_:%.+]] = [[CST_0_1_]], [[VAR_arg11_:%.+]] = [[VAR_2_]], [[VAR_arg12_:%.+]] = [[CST_0_1_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 { +// CHECK-DAG: [[VAR_4_1_:%.+]] = arith.muli [[VAR_arg8_]], [[CST_2_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.index_cast [[VAR_4_1_]] : i32 to index +// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[VAR_4_1_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg9_]], [[VAR_6_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg9_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_1_:%.+]]: i32, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32): +// CHECK: [[VAR_15_:%.+]] = arith.addi [[IN_1_]], [[IN_2_]] : i32 +// CHECK: linalg.yield [[VAR_15_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_5_1_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg12_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_0_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> () +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_7_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: i32): +// CHECK: [[VAR_15_1_:%.+]] = arith.addi [[IN_4_]], [[IN_5_]] : i32 +// CHECK: linalg.yield [[VAR_15_1_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg11_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg11_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_7_:%.+]]: i32, [[IN_8_:%.+]]: i32, [[IN_9_:%.+]]: i32): +// CHECK: [[VAR_15_2_:%.+]] = arith.addi [[IN_7_]], [[IN_8_]] : i32 +// CHECK: linalg.yield [[VAR_15_2_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_8_]], [[CST_4_1_]] : index +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_arg12_]], [[CST_4_1_]] : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]]:4 = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_3_]] step [[CST_1_]] iter_args([[VAR_arg14_:%.+]] = [[VAR_10_]], [[VAR_arg15_:%.+]] = 
[[VAR_12_]], [[VAR_arg16_:%.+]] = [[VAR_11_]], [[VAR_arg17_:%.+]] = [[VAR_13_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 { +// CHECK-DAG: [[VAR_15_3_:%.+]] = arith.muli [[VAR_arg13_]], [[CST_3_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_3_]] : i32 to index +// CHECK-DAG: [[VAR_17_:%.+]] = linalg.fill ins([[VAR_15_3_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32> +// CHECK: [[VAR_18_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg14_]], [[VAR_17_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg14_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_10_:%.+]]: i32, [[IN_11_:%.+]]: i32, [[IN_12_:%.+]]: i32): +// CHECK: [[VAR_25_:%.+]] = arith.addi [[IN_10_]], [[IN_11_]] : i32 +// CHECK: linalg.yield [[VAR_25_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_16_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_19_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4xf32> +// CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_1_]] : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg17_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>> +// CHECK: bufferization.materialize_in_destination [[VAR_20_]] in writable [[VAR_reinterpret_cast_3_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> () +// CHECK: [[VAR_21_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_18_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_18_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_13_:%.+]]: i32, [[IN_14_:%.+]]: i32, [[IN_15_:%.+]]: i32): +// CHECK: [[VAR_25_1_:%.+]] = arith.addi [[IN_13_]], [[IN_14_]] : i32 +// CHECK: linalg.yield [[VAR_25_1_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK: [[VAR_22_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg16_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg16_]] : tensor<4xi32>) { +// CHECK: ^bb0([[IN_16_:%.+]]: i32, [[IN_17_:%.+]]: i32, [[IN_18_:%.+]]: i32): +// CHECK: [[VAR_25_2_:%.+]] = arith.addi [[IN_16_]], [[IN_17_]] : i32 +// CHECK: linalg.yield [[VAR_25_2_]] : i32 +// CHECK: } -> tensor<4xi32> +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addi [[VAR_19_]], [[CST_4_1_]] : index +// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_arg17_]], [[CST_4_1_]] : index +// CHECK: scf.yield [[VAR_21_]], [[VAR_23_]], [[VAR_22_]], [[VAR_24_]] : tensor<4xi32>, index, tensor<4xi32>, index +// CHECK: } +// CHECK: scf.yield [[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2, [[VAR_14_]]#3 : tensor<4xi32>, index, tensor<4xi32>, index +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir new file mode 100644 index 00000000..3c96c2d3 --- /dev/null +++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir @@ -0,0 +1,51 @@ +// RUN: 
triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s + +// IR from python/examples/test_tensor_index_iterargs.py +module { + tt.func public @addptr_with_masks(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %cst = arith.constant dense<-1.100000e+01> : tensor<4xf32> + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<4> : tensor<4xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg2 : i32 -> tensor<4xi32> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %4:2 = scf.for %arg3 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %5 = arith.cmpi slt, %arg4, %1 : tensor<4xi32> + %6 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6, %5, %cst : tensor<4x!tt.ptr> + %8 = tt.addptr %3, %arg5 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %8, %7 : tensor<4x!tt.ptr> + %9 = arith.addi %arg4, %cst_0 : tensor<4xi32> + %10 = arith.addi %arg5, %cst_0 : tensor<4xi32> + scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK: tt.func public @addptr_with_masks([[arg0_:.+]]: !tt.ptr, [[arg1_:.+]]: !tt.ptr, [[arg2_:.+]]: i32) attributes {noinline = false} { +// CHECK-DAG: [[CST_minus_1_dot_100000_:%.+]] = arith.constant -1.100000e+01 : f32 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]]:2 = scf.for [[VAR_arg3_:%.+]] = [[CST_0_]] to [[CST_4_1_]] step [[CST_1_]] iter_args([[VAR_arg4_:%.+]] = [[CST_0_1_]], [[VAR_arg5_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 { +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_4_]] : index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[arg2_]] : i32 to index +// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_arg4_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = "tts.load"([[VAR_1_]], [[VAR_5_]], [[CST_minus_1_dot_100000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>, index, f32) -> tensor<4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg5_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> +// CHECK: "tts.store"([[VAR_7_]], [[VAR_6_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> () +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_4_]] : index +// CHECK: scf.yield [[VAR_2_]], [[VAR_8_]] : index, index +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir new file mode 100644 index 00000000..d868714c --- /dev/null +++ 
b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_nested.mlir
@@ -0,0 +1,77 @@
+// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s
+
+// IR from python/examples/test_tensor_index_iterargs.py
+module {
+  tt.func public @tensor_indices_nested(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c3_i32 = arith.constant 3 : i32
+    %cst = arith.constant dense<4> : tensor<4xi32>
+    %c2_i32 = arith.constant 2 : i32
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr>
+    %2 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr>
+    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %4 = arith.muli %arg2, %c2_i32 : i32
+      %5 = tt.splat %4 : i32 -> tensor<4xi32>
+      %6 = arith.addi %arg3, %5 : tensor<4xi32>
+      %7 = tt.addptr %1, %6 : tensor<4x!tt.ptr>, tensor<4xi32>
+      %8 = tt.load %7 : tensor<4x!tt.ptr>
+      %9 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32>
+      tt.store %9, %8 : tensor<4x!tt.ptr>
+      %10 = arith.addi %6, %cst : tensor<4xi32>
+      %11 = arith.addi %arg4, %cst : tensor<4xi32>
+      %12:2 = scf.for %arg5 = %c0_i32 to %c3_i32 step %c1_i32 iter_args(%arg6 = %10, %arg7 = %11) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+        %13 = arith.muli %arg5, %c3_i32 : i32
+        %14 = tt.splat %13 : i32 -> tensor<4xi32>
+        %15 = arith.addi %arg6, %14 : tensor<4xi32>
+        %16 = tt.addptr %1, %15 : tensor<4x!tt.ptr>, tensor<4xi32>
+        %17 = tt.load %16 : tensor<4x!tt.ptr>
+        %18 = tt.addptr %2, %arg7 : tensor<4x!tt.ptr>, tensor<4xi32>
+        tt.store %18, %17 : tensor<4x!tt.ptr>
+        %19 = arith.addi %15, %cst : tensor<4xi32>
+        %20 = arith.addi %arg7, %cst : tensor<4xi32>
+        scf.yield %19, %20 : tensor<4xi32>, tensor<4xi32>
+      }
+      scf.yield %12#0, %12#1 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @tensor_indices_nested([[arg0_:.+]]: !tt.ptr, [[arg1_:.+]]: !tt.ptr) attributes {noinline = false} {
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_0_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[CST_0_1_]], [[VAR_arg4_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 {
+// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_arg2_]], [[CST_2_]] : i32
+// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
+// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_arg3_]], [[VAR_2_]] : index
+// CHECK: [[VAR_4_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_3_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr>
+// CHECK-DAG: [[VAR_5_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>) -> tensor<4xf32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr>
+// CHECK: "tts.store"([[VAR_6_]], [[VAR_5_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_3_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_4_]] : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_9_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_3_]] step [[CST_1_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_7_]], [[VAR_arg7_:%.+]] = [[VAR_8_]]) -> (index, index) : i32 {
+// CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[VAR_arg5_]], [[CST_3_]] : i32
+// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : i32 to index
+// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_arg6_]], [[VAR_11_]] : index
+// CHECK: [[VAR_13_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_12_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr>
+// CHECK-DAG: [[VAR_14_:%.+]] = "tts.load"([[VAR_13_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>) -> tensor<4xf32>
+// CHECK-DAG: [[VAR_15_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg7_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr>
+// CHECK: "tts.store"([[VAR_15_]], [[VAR_14_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_12_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_4_]] : index
+// CHECK: scf.yield [[VAR_16_]], [[VAR_17_]] : index, index
+// CHECK: }
+// CHECK: scf.yield [[VAR_9_]]#0, [[VAR_9_]]#1 : index, index
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir
new file mode 100644
index 00000000..26732ed7
--- /dev/null
+++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_e2e.mlir
@@ -0,0 +1,43 @@
+// IR obtained from "test_integer_tensor" in python/examples/test_tensor_index_iterargs.py
+
+// RUN: triton-shared-opt --triton-to-structured --cse --canonicalize --remove-dead-values %s | FileCheck %s
+module {
+  tt.func public @test_1(%arg0: !tt.ptr) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr>
+    %2:2 = scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg2 = %0, %arg3 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %3 = tt.addptr %1, %arg2 : tensor<4x!tt.ptr>, tensor<4xi32>
+      %4 = arith.sitofp %arg3 : tensor<4xi32> to tensor<4xf32>
+      tt.store %3, %4 : tensor<4x!tt.ptr>
+      %5 = arith.addi %arg2, %cst : tensor<4xi32>
+      %6 = arith.addi %arg3, %cst : tensor<4xi32>
+      scf.yield %5, %6 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @test_1([[arg0_:.+]]: !tt.ptr) attributes {noinline = false} {
+// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<4xi32>
+// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_1_:%.+]]:2 = scf.for [[VAR_arg1_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg2_:%.+]] = [[CST_0_]], [[VAR_arg3_:%.+]] = [[VAR_0_]]) -> (index, tensor<4xi32>) : i32 {
+// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_]]{{.}}, offsets: {{.}}[[VAR_arg2_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr>
+// CHECK-DAG: [[VAR_3_:%.+]] = arith.sitofp [[VAR_arg3_]] : tensor<4xi32> to tensor<4xf32>
+// CHECK: "tts.store"([[VAR_2_]], [[VAR_3_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> ()
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_arg3_]], [[VAR_cst_]] : tensor<4xi32>
+// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_4_]] : index
+// CHECK: scf.yield [[VAR_5_]], [[VAR_4_]] : index, tensor<4xi32>
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir
new file mode 100644
index 00000000..5219da88
--- /dev/null
+++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterargs_not_used_ptranalysis_prepass.mlir
@@ -0,0 +1,44 @@
+// IR obtained from "test_integer_tensor" in python/examples/test_tensor_index_iterargs.py
+
+// RUN: triton-shared-opt --triton-to-structured="run-prepass-only" %s | FileCheck %s
+module {
+  tt.func public @test_1(%arg0: !tt.ptr) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr>
+    %2:2 = scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg2 = %0, %arg3 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
+      %3 = tt.addptr %1, %arg2 : tensor<4x!tt.ptr>, tensor<4xi32>
+      %4 = arith.sitofp %arg3 : tensor<4xi32> to tensor<4xf32>
+      tt.store %3, %4 : tensor<4x!tt.ptr>
+      %5 = arith.addi %arg2, %cst : tensor<4xi32>
+      %6 = arith.addi %arg3, %cst : tensor<4xi32>
+      scf.yield %5, %6 : tensor<4xi32>, tensor<4xi32>
+    }
+    tt.return
+  }
+}
+
+// CHECK: tt.func public @test_1([[arg0_:.+]]: !tt.ptr) attributes {noinline = false} {
+// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<4xi32>
+// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+// CHECK-DAG: [[VAR_1_:%.+]] = tt.splat [[arg0_]] : !tt.ptr -> tensor<4x!tt.ptr>
+// CHECK: [[structured_:%.+]], [[offsets_:%.+]], [[VAR_strides_:%.+]] = "tts.get_structured_state"([[VAR_0_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK: [[structured_0_:%.+]], [[offsets_1_:%.+]], [[VAR_strides_2_:%.+]] = "tts.get_structured_state"([[VAR_0_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK-DAG: [[VAR_2_:%.+]]:6 = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg2_:%.+]] = [[structured_]], [[VAR_arg3_:%.+]] = [[offsets_]], [[VAR_arg4_:%.+]] = [[VAR_strides_]], [[VAR_arg5_:%.+]] = [[structured_0_]], [[VAR_arg6_:%.+]] = [[offsets_1_]], [[VAR_arg7_:%.+]] = [[VAR_strides_2_]]) -> (tensor<4xi32>, index, index, tensor<4xi32>, index, index) : i32 {
+// CHECK-DAG: [[VAR_3_:%.+]] = tt.addptr [[VAR_1_]], [[VAR_arg2_]] : tensor<4x!tt.ptr>, tensor<4xi32>
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.sitofp [[VAR_arg5_]] : tensor<4xi32> to tensor<4xf32>
+// CHECK: tt.store [[VAR_3_]], [[VAR_4_]] : tensor<4x!tt.ptr>
+// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg2_]], [[VAR_cst_]] : tensor<4xi32>
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg5_]], [[VAR_cst_]] : tensor<4xi32>
+// CHECK: [[structured_3_:%.+]], [[offsets_4_:%.+]], [[VAR_strides_5_:%.+]] = "tts.get_structured_state"([[VAR_5_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK: [[structured_6_:%.+]], [[offsets_7_:%.+]], [[VAR_strides_8_:%.+]] = "tts.get_structured_state"([[VAR_6_]]) <{resultSegmentSizes = array}> : (tensor<4xi32>) -> (tensor<4xi32>, index, index)
+// CHECK: scf.yield [[structured_3_]], [[offsets_4_]], [[VAR_strides_5_]], [[structured_6_]], [[offsets_7_]], [[VAR_strides_8_]] : tensor<4xi32>, index, index, tensor<4xi32>, index, index
+// CHECK: }
+// CHECK: tt.return
+// CHECK: }
diff --git a/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir b/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir
index d62898fc..f60b4a20 100644
--- a/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir
+++ b/test/Conversion/TritonToStructured/wraparound_unsupported_add_offset.mlir
@@ -57,10 +57,11 @@ module {
   }
 }
 
-// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
+// CHECK: tt.func public @wrap_side_by_side_masked_loop_01234567([[arg0_:.+]]: !tt.ptr, [[arg1_:.+]]: !tt.ptr, [[arg2_:.+]]: i32, [[arg3_:.+]]: i32, [[arg4_:.+]]: i32, [[arg5_:.+]]: i32, [[arg6_:.+]]: i32, [[arg7_:.+]]: i32) {
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-9.900000e+01> : tensor<4x4xf32>
 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32
 // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
 // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2> : tensor<4x1xi32>
 // CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<6> : tensor<4xi32>
@@ -69,43 +70,41 @@ module {
 // CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_1_:%.+]] = arith.addi [[VAR_0_]], [[VAR_cst_2_]] : tensor<4xi32>
-// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[PARAM_3_]] : i32 -> tensor<4xi32>
+// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[arg3_]] : i32 -> tensor<4xi32>
 // CHECK: [[VAR_3_:%.+]] = arith.remsi [[VAR_0_]], [[VAR_2_]] : tensor<4xi32>
 // CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_3_]], [[VAR_cst_1_]] : tensor<4xi32>
 // CHECK-DAG: [[VAR_5_:%.+]] = tt.expand_dims [[VAR_1_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
-// CHECK-DAG: [[VAR_6_:%.+]] = tt.splat [[PARAM_4_]] : i32 -> tensor<4x1xi32>
+// CHECK-DAG: [[VAR_6_:%.+]] = tt.splat [[arg4_]] : i32 -> tensor<4x1xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_5_]], [[VAR_6_]] : tensor<4x1xi32>
 // CHECK-DAG: [[VAR_8_:%.+]] = tt.expand_dims [[VAR_4_]] {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
-// CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[PARAM_5_]] : i32 -> tensor<1x4xi32>
+// CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[arg5_]] : i32 -> tensor<1x4xi32>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[VAR_8_]], [[VAR_9_]] : tensor<1x4xi32>
 // CHECK-DAG: [[VAR_11_:%.+]] = tt.broadcast [[VAR_7_]] : tensor<4x1xi32> -> tensor<4x4xi32>
 // CHECK: [[VAR_12_:%.+]] = tt.broadcast [[VAR_10_]] : tensor<1x4xi32> -> tensor<4x4xi32>
 // CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_11_]], [[VAR_12_]] : tensor<4x4xi32>
-// CHECK-DAG: [[VAR_14_:%.+]] = tt.splat [[PARAM_0_]] : !tt.ptr -> tensor<4x4x!tt.ptr>
+// CHECK-DAG: [[VAR_14_:%.+]] = tt.splat [[arg0_]] : !tt.ptr -> tensor<4x4x!tt.ptr>
 // CHECK-NOT: separator of consecutive DAGs
 // CHECK-DAG: [[VAR_15_:%.+]] = tt.addptr [[VAR_14_]], [[VAR_13_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32>
 // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
-// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
+// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[arg6_]] : i32 to index
+// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[arg7_]] : i32 to index
+// CHECK: [[VAR_19_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
+// CHECK-DAG: [[VAR_20_:%.+]] = tt.broadcast [[VAR_19_]] : tensor<4x1xi1> -> tensor<4x4xi1>
+// CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[arg4_]], [[CST_4_]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_19_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: [0, 0], shape: [0, 0], order: [] : to tensor<4x4x!tt.ptr>
-// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
+// CHECK-DAG: [[VAR_22_:%.+]] = tt.splat [[VAR_21_]] : i32 -> tensor<4x4xi32>
+// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[arg5_]], [[CST_4_]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1>
-// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32>
-// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_25_:%.+]] = tt.splat [[VAR_24_]] : i32 -> tensor<4x4xi32>
-// CHECK-DAG: [[VAR_26_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_19_]]) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 {
-// CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_21_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr>
-// CHECK: tt.store [[VAR_arg10_]], [[LOAD_VAR_arg9_MEM_]] : tensor<4x4x!tt.ptr>
-// CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_23_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32>
-// CHECK-DAG: [[VAR_29_:%.+]] = tt.addptr [[VAR_arg10_]], [[VAR_25_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32>
-// CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>
+// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_23_]] : i32 to index
+// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[CST_0_]]) -> (tensor<4x4x!tt.ptr>, index) : i32 {
+// CHECK-DAG: [[VAR_26_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4, 4], strides: {{.}}[[VAR_17_]], [[VAR_18_]]{{.}}, offsets: {{.}}[[VAR_arg10_]], [[CST_0_]]{{.}}, shape: [0, 0], order: [] : to tensor<4x4x!tt.ptr>
+// CHECK-DAG: [[LOAD_VAR_arg9_MEM_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_20_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr>
+// CHECK: "tts.store"([[VAR_26_]], [[LOAD_VAR_arg9_MEM_]]) <{static_mask_dims = array}> : (tensor<4x4x!tt.ptr>, tensor<4x4xf32>) -> ()
+// CHECK-DAG: [[VAR_28_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_22_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32>
+// CHECK-DAG: [[VAR_29_:%.+]] = arith.addi [[VAR_arg10_]], [[VAR_24_]] : index
+// CHECK: scf.yield [[VAR_28_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, index
 // CHECK: }
 // CHECK: tt.return
 // CHECK: }