Adding a test for swap kernel (#218)
Following up on the conversation about copy optimization.

---------

Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>
parsifal-47 authored Jan 13, 2025
1 parent d1c5441 commit 8b9f5dd
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions python/examples/test_swap.py
@@ -0,0 +1,44 @@
import torch

import triton
import triton.language as tl

# This kernel and test catch incorrectly optimized kernels in which copy
# elimination is applied erroneously in the absence of explicit memory
# allocation. Such optimization bugs surface when swapping two arrays:
# if the intermediate copy is dropped or memory access is mismanaged,
# both arrays can unintentionally end up holding the same data.
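#
# As a plain-Python illustration of the failure mode: a correct swap needs
# intermediate storage,
#     tmp = x; x = y; y = tmp
# while an erroneously copy-eliminated swap degenerates to
#     x = y; y = x
# leaving both arrays holding the original contents of y.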

@triton.jit
def swap_kernel(
    x_ptr,  # *Pointer* to first inout vector.
    y_ptr,  # *Pointer* to second inout vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so the axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Load both blocks into registers before either store, so the values to
    # swap are captured in intermediate storage.
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(x_ptr + offsets, y)
    tl.store(y_ptr + offsets, x)


def swap(x: torch.Tensor, y: torch.Tensor):
    n_elements = x.numel()
    # One program instance per BLOCK_SIZE-element block.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    swap_kernel[grid](x, y, BLOCK_SIZE=1024)
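    # For example, with size = 10240 and BLOCK_SIZE = 1024 as used in the
    # test below, the grid lambda evaluates to (10,): ten program instances,
    # each swapping one 1024-element block.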


def test(device):
    torch.manual_seed(0)
    size = 10240  # An exact multiple of BLOCK_SIZE, so the kernel needs no mask.
    x = torch.rand(size, device=device)
    y = torch.rand(size, device=device)
    assert not torch.equal(x, y)
    x_ = x.clone()
    y_ = y.clone()
    swap(x, y)
    # A miscompiled swap would leave x and y holding the same data.
    assert torch.equal(x, y_)
    assert torch.equal(y, x_)
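

# A minimal sketch for running the test standalone (an assumption, not part
# of the original commit): pass whichever device string your Triton backend
# supports.
if __name__ == "__main__":
    test("cuda")  # assumes a CUDA-capable device is available
    print("swap test passed")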
