-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix mask analysis for when the entire tensor is masked off (#186)
The current formula for computing masks does not work when the mask bound is smaller than the start of the mask range: ``` ---|-------|-----------| ^ ^ ^ bound start end ``` Current formula: ``` new_end = min(end, bound) new_dim = new_end - start ``` For the above case, this formula will produce a negative `new_dim`. To fix this issue, we optionally move `new_end` back to `start` so that when `bound < start`, `new_dim` is 0. The new formula is: ``` new_end_tmp = min(end, bound) new_end = max(new_end_tmp, start) new_dim = new_end - start ``` Another formula that could work in theory is to do: ``` new_end = min(end, bound) new_dim_potentially_neg = new_end - start new_dim = max(new_dim_potentially_neg, 0) ``` But this approach does not work in MaskAnalysis because we operate on the `index` type which is unsigned. We would have a negative overflow when computing `new_dim_potentially_neg` and end up getting a positive number instead. # Changes + Update the formula + The change is quite invasive, so I added a flag in cases we don't want to enable this fix + Update lit tests + Removed some of the old TritonToLinalg tests; we will remove the old pass in a future PR
- Loading branch information
1 parent
177a624
commit d8c8f29
Showing
38 changed files
with
819 additions
and
1,874 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
|
||
import triton | ||
import triton.language as tl | ||
|
||
from triton.backends.triton_shared.driver import CPUDriver | ||
|
||
|
||
def test_mask(device): | ||
@triton.jit | ||
def test(in0, out0): | ||
offs = 100 + tl.arange(0, 4) | ||
out_offs = tl.arange(0, 4) | ||
a = tl.load(in0 + offs, mask=offs < 4, other=-1) | ||
tl.store(out0 + out_offs, a) | ||
|
||
SIZE = 8 | ||
input = torch.arange(0, SIZE, device=device, dtype=torch.int32) | ||
output = torch.full((SIZE,), -2, device=device, dtype=torch.int32) | ||
|
||
if device == 'cpu': | ||
triton.runtime.driver.set_active(CPUDriver()) | ||
|
||
grid = lambda meta: (1,) | ||
|
||
src = triton.compiler.ASTSource( | ||
fn=test, | ||
signature="*fp32,*fp32,i32", | ||
) | ||
ret = triton.compile( | ||
src, | ||
) | ||
print(ret.asm["ttir"]) | ||
|
||
print(output) | ||
test[grid](input, output) | ||
print(input) | ||
print(output) | ||
torch.testing.assert_close(output, torch.tensor([-1, -1, -1, -1, -2, -2, -2, -2], device=device, dtype=torch.int32)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.