Skip to content

Commit

Permalink
Fix mask analysis for when the entire tensor is masked off (#186)
Browse files Browse the repository at this point in the history
The current formula for computing masks does not work when the mask
bound is smaller than the start of the mask range:

```
---|-------|-----------|
   ^       ^           ^
bound    start        end
```

Current formula:

```
new_end = min(end, bound)
new_dim = new_end - start
```

For the above case, this formula will produce a negative `new_dim`. To
fix this issue, we optionally move `new_end` back to `start` so that
when `bound < start`, `new_dim` is 0.

The new formula is:

```
new_end_tmp = min(end, bound)
new_end = max(new_end_tmp, start)
new_dim = new_end - start
```

Another formula that could work in theory is to do:

```
new_end = min(end, bound)
new_dim_potentially_neg = new_end - start
new_dim = max(new_dim_potentially_neg, 0)
```

But this approach does not work in MaskAnalysis because we operate on
the `index` type which is unsigned. Computing `new_dim_potentially_neg`
would underflow (wrap around) whenever `new_end < start`, so we would end
up with a large positive number instead of a negative one, and the final
`max(..., 0)` would have no effect.

# Changes
+ Update the formula
+ The change is quite invasive, so I added a flag (`use-unsafe-mask`) to
opt out of the fix in cases where we don't want it enabled
+ Update lit tests
+ Removed some of the old TritonToLinalg tests; we will remove the old
pass in a future PR
  • Loading branch information
nhat-nguyen authored Oct 25, 2024
1 parent 177a624 commit d8c8f29
Show file tree
Hide file tree
Showing 38 changed files with 819 additions and 1,874 deletions.
3 changes: 3 additions & 0 deletions include/triton-shared/Analysis/MaskAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ struct MaskState {
OpFoldResult end;
SmallVector<OpFoldResult> dims;
OpFoldResult scalar;
const bool useUnsafeMask;

MaskState(bool useUnsafeMask = false) : useUnsafeMask(useUnsafeMask) {}

int64_t getRank() const { return dims.size(); }

Expand Down
2 changes: 2 additions & 0 deletions include/triton-shared/Analysis/OpFoldResultUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
const Location loc, OpBuilder &b);

// Returns the signed maximum of `lhs` and `rhs`.
// When both operands are constant integers the result is folded into an
// index attribute; otherwise any constant operand is materialized as an
// arith::ConstantOp and an arith::MaxSIOp is created with `b` at `loc`.
OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
const Location loc, OpBuilder &b);
} // namespace mlir

#endif
6 changes: 3 additions & 3 deletions include/triton-shared/AnalysisStructured/PtrAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ class PtrAnalysis {
// strides, offsets, and modulos.
LogicalResult rewriteForOp(scf::ForOp op);

LogicalResult rewriteLoadOp(triton::LoadOp op);
LogicalResult rewriteLoadOp(triton::LoadOp op, bool useUnsafeMask = false);

LogicalResult rewriteStoreOp(triton::StoreOp op);
LogicalResult rewriteStoreOp(triton::StoreOp op, bool useUnsafeMask = false);

LogicalResult rewriteOp(Operation *op);
LogicalResult rewriteOp(Operation *op, bool useUnsafeMask = false);
};

} // namespace tts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> {
Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false",
"Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">,
Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false",
"Skip the prepass">
"Skip the prepass">,
Option<"useUnsafeMask", "use-unsafe-mask", "bool", /*default*/"false",
"Assume that the mask bounds are never less than starting offsets. May produce incorrect results.">
];
}

Expand Down
16 changes: 14 additions & 2 deletions lib/Analysis/MaskAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

#include "triton-shared/Analysis/MaskAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LogicalResult.h"

Expand Down Expand Up @@ -341,7 +339,21 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
assert(cmpDim != -1 &&
"Unexpected case where no dimension has size larger than 1");

// Important:
// In the case where the values we are loading are entirely masked off like
// the following:
//
// ---|-------|-----------|
//    ^       ^           ^
// scalar   start        end
//
// newEnd = min(end, scalar) = scalar
// Now scalar < start, so simply doing dim = newEnd - start is incorrect.
//
// The correct formula is to optionally move `newEnd` back to `start` using
// max(newEnd, start), so that the resulting dim is 0 instead of negative.
auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder);
newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
auto newDim = subOFRs(newEnd, lhsState.start, loc, builder);

for (int32_t i = 0; i < lhsState.getRank(); i++) {
Expand Down
28 changes: 28 additions & 0 deletions lib/Analysis/OpFoldResultUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,32 @@ OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
return minOp.getResult();
}

// Computes the signed maximum of two OpFoldResults.
//
// If both operands are constant integers, the maximum is folded and returned
// as an index attribute without emitting any IR. Otherwise each constant
// operand is first materialized as an arith::ConstantOp (lhs before rhs) and
// the result of a freshly created arith::MaxSIOp is returned.
OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
                     const Location loc, OpBuilder &b) {
  const auto lhsConst = getIntAttr(lhs);
  const auto rhsConst = getIntAttr(rhs);

  // Fast path: everything is known statically, so fold right away.
  if (lhsConst && rhsConst) {
    const auto folded = std::max(lhsConst.value(), rhsConst.value());
    return b.getIndexAttr(folded);
  }

  // Turns an operand into an SSA value: constants become arith.constant ops,
  // anything else is expected to already hold a Value.
  auto toValue = [&](const OpFoldResult operand, const auto &constant) -> Value {
    if (!constant)
      return dyn_cast<Value>(operand);
    return b
        .create<arith::ConstantOp>(loc, b.getIndexAttr(constant.value()))
        .getResult();
  };

  // Keep lhs-before-rhs materialization order so the emitted IR is unchanged.
  Value lhsValue = toValue(lhs, lhsConst);
  Value rhsValue = toValue(rhs, rhsConst);

  return b.create<arith::MaxSIOp>(loc, lhsValue, rhsValue).getResult();
}

} // namespace mlir
16 changes: 9 additions & 7 deletions lib/AnalysisStructured/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,8 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
return success();
}

LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op,
bool useUnsafeMask) {
auto ptr = ptrMap.lookupOrNull(op.getPtr());
auto mask = op.getMask();
auto other = op.getOther();
Expand All @@ -1109,7 +1110,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
}

ArrayRef<OpFoldResult> dims;
mlir::triton::MaskState mstate;
mlir::triton::MaskState mstate(useUnsafeMask);
Value scalarOther;

OpBuilder builder(op);
Expand Down Expand Up @@ -1226,7 +1227,8 @@ void PtrAnalysis::initializeMaybeStructuredArgs(Operation *op) {
}
}

LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op,
bool useUnsafeMask) {
auto ptr = ptrMap.lookupOrNull(op.getPtr());
auto val = op.getValue();
auto mask = op.getMask();
Expand All @@ -1245,7 +1247,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
}

ArrayRef<OpFoldResult> dims;
mlir::triton::MaskState mstate;
mlir::triton::MaskState mstate(useUnsafeMask);

OpBuilder builder(op);

Expand All @@ -1270,7 +1272,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
return success();
}

LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) {
LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) {
LLVM_DEBUG({
llvm::dbgs() << "rewriting rootOp\n";
rootOp->dump();
Expand Down Expand Up @@ -1301,14 +1303,14 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) {
return WalkResult::advance();
})
.Case<triton::LoadOp>([&](auto load) {
if (rewriteLoadOp(load).failed()) {
if (rewriteLoadOp(load, useUnsafeMask).failed()) {
load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp");
return WalkResult::advance();
}
return WalkResult::skip();
})
.Case<triton::StoreOp>([&](auto store) {
if (rewriteStoreOp(store).failed()) {
if (rewriteStoreOp(store, useUnsafeMask).failed()) {
store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp");
return WalkResult::advance();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class TritonToStructuredPass
mlir::tts::PtrAnalysis ptrAnalysis;
ptrAnalysis.initializeMaybeStructuredArgs(moduleOp);

if (failed(ptrAnalysis.rewriteOp(moduleOp))) {
if (failed(ptrAnalysis.rewriteOp(moduleOp, useUnsafeMask))) {
moduleOp->emitWarning("PtrAnalysis failed");
}

Expand Down
39 changes: 39 additions & 0 deletions python/examples/test_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch

import triton
import triton.language as tl

from triton.backends.triton_shared.driver import CPUDriver


def test_mask(device):
    """Regression test for a load whose mask disables every element.

    The kernel reads at offsets 100..103 with the mask ``offs < 4``, so the
    mask bound (4) lies below the start of the range (100) and all four lanes
    are masked off. Every loaded lane must therefore take the ``other`` value
    (-1), and only the first four output elements may be overwritten.
    """
    @triton.jit
    def test(in0, out0):
        offs = 100 + tl.arange(0, 4)
        out_offs = tl.arange(0, 4)
        a = tl.load(in0 + offs, mask=offs < 4, other=-1)
        tl.store(out0 + out_offs, a)

    SIZE = 8
    input_tensor = torch.arange(0, SIZE, device=device, dtype=torch.int32)
    output = torch.full((SIZE,), -2, device=device, dtype=torch.int32)

    if device == 'cpu':
        triton.runtime.driver.set_active(CPUDriver())

    grid = lambda meta: (1,)

    # Compile ahead of time only to dump the TTIR for debugging. The signature
    # mirrors the launch below: two int32 pointers and no extra scalar.
    src = triton.compiler.ASTSource(
        fn=test,
        signature="*i32,*i32",
    )
    ret = triton.compile(
        src,
    )
    print(ret.asm["ttir"])

    print(output)
    test[grid](input_tensor, output)
    print(input_tensor)
    print(output)

    # The first four elements come from the fully masked load (-1); the rest
    # keep their initial fill value (-2).
    expected = torch.tensor([-1, -1, -1, -1, -2, -2, -2, -2],
                            device=device, dtype=torch.int32)
    torch.testing.assert_close(output, expected)
35 changes: 0 additions & 35 deletions python/examples/test_tensor_index_iterargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,38 +112,3 @@ def test_1(out0):
src,
)
print(ret.asm["ttir"])



def disabled_test_mask(device):
# Disabled test for a load whose mask (`offs < 0` with offs = 4..7) is false
# for every element, so all lanes should take the `other` value (-1).
# NOTE(review): test_1 and test_2 have identical bodies and `batch` is unused
# in both; only test_1 is launched below.
# TODO: This fails to compile in StructuredToMemref
@triton.jit
def test_1(in0, out0, batch):
offs = 4 + tl.arange(0, 4)
out_offs = tl.arange(0, 4)
a = tl.load(in0 + offs, mask=offs < 0, other=-1)
tl.store(out0 + out_offs, a)

# TODO: This segfaults in the CPU backend
# Crashes when the batch value will mask off all of the tensors
@triton.jit
def test_2(in0, out0, batch):
offs = 4 + tl.arange(0, 4)
out_offs = tl.arange(0, 4)
a = tl.load(in0 + offs, mask=offs < 0, other=-1)
tl.store(out0 + out_offs, a)


SIZE = 8
input = torch.arange(0, SIZE, device=device, dtype=torch.int32)
output = torch.full((SIZE,), -1, device=device, dtype=torch.int32)

if device == 'cpu':
triton.runtime.driver.set_active(CPUDriver())

# Launch a single program instance.
grid = lambda meta: (1,)

print(output)
test_1[grid](input, output, 0)
print(input)
print(output)
Loading

0 comments on commit d8c8f29

Please sign in to comment.