From ebe38458531ad02286859540e39fe6aabdf80e36 Mon Sep 17 00:00:00 2001
From: Renat Idrisov <4032256+parsifal-47@users.noreply.github.com>
Date: Fri, 3 Jan 2025 09:27:13 -0800
Subject: [PATCH] Adding benchmarks for vector addition, matmul, softmax, and layernorm (#209)

We cannot use the standard Triton benchmarks, as brought up in
https://github.com/microsoft/triton-shared/issues/199, because they are
specific to GPUs. This change therefore adds simple CPU benchmarks for
vector addition, matmul, softmax, and layernorm.

Sample output:

```sh
$ python test_softmax.py
bench_softmax(1024, 'torch') {}, 20 times, all results in seconds
Wall: Avg=0.006537, min=0.005301, std=0.000326, max=0.006723
CPU: Avg=0.123649, min=0.010989, std=0.026653, max=0.140211
bench_softmax(1024, 'triton') {}, 20 times, all results in seconds
Wall: Avg=0.102619, min=0.014122, std=0.384826, max=1.780037
CPU: Avg=0.028643, min=0.014123, std=0.062372, max=0.300513
bench_softmax(2048, 'torch') {}, 20 times, all results in seconds
Wall: Avg=0.015215, min=0.013364, std=0.002282, max=0.022841
CPU: Avg=0.172217, min=0.043525, std=0.037402, max=0.231176
bench_softmax(2048, 'triton') {}, 20 times, all results in seconds
Wall: Avg=0.071460, min=0.055257, std=0.068684, max=0.370846
CPU: Avg=0.062689, min=0.055258, std=0.030449, max=0.195406
bench_softmax(4096, 'torch') {}, 20 times, all results in seconds
Wall: Avg=0.056267, min=0.056117, std=0.000134, max=0.056681
CPU: Avg=0.313888, min=0.220500, std=0.023960, max=0.338866
bench_softmax(4096, 'triton') {}, 20 times, all results in seconds
Wall: Avg=0.258867, min=0.244147, std=0.062352, max=0.530646
CPU: Avg=0.249397, min=0.244141, std=0.021087, max=0.341300
```
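For reference, a new benchmark only needs the `measure` decorator and, when a
Triton kernel is involved, a call to `select_cpu_backend()` from the new
`benchmark` module added below. The sketch is illustrative only; `bench_relu`
and the sizes are made up and are not part of this patch:

```py
import torch

import benchmark


# Hypothetical benchmark, shown only to illustrate how benchmark.measure is used.
@benchmark.measure(repeats=20, percentiles=(50, 90))
def bench_relu(size, provider):
    x = torch.randn(size, device='cpu')
    if provider == 'torch':
        torch.relu(x)


if __name__ == "__main__":
    # Route Triton kernels (if any are benchmarked) to the triton-shared CPU backend.
    benchmark.select_cpu_backend()
    for size in [2**10, 2**12]:
        bench_relu(size, 'torch')
```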
---------

Co-authored-by: Renat Idrisov
---
 python/examples/bare_matmul.py    | 45 +++++++++++++++++++++
 python/examples/benchmark.py      | 66 +++++++++++++++++++++++++++++++
 python/examples/test_layernorm.py | 25 ++++++++++++
 python/examples/test_matmul.py    | 19 ++++++++-
 python/examples/test_softmax.py   | 18 +++++++++
 python/examples/test_vec_add.py   | 17 ++++++++
 6 files changed, 189 insertions(+), 1 deletion(-)
 create mode 100644 python/examples/bare_matmul.py
 create mode 100644 python/examples/benchmark.py

diff --git a/python/examples/bare_matmul.py b/python/examples/bare_matmul.py
new file mode 100644
index 00000000..eaa408b1
--- /dev/null
+++ b/python/examples/bare_matmul.py
@@ -0,0 +1,45 @@
+# This benchmark multiplies square matrices using a single block of maximum
+# size to measure the performance of the tl.dot operation.
+
+import torch
+import triton
+import triton.language as tl
+import benchmark
+
+
+@triton.jit
+def bare_matmul(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr):
+    pid_x = tl.program_id(0)  # block row id
+    pid_y = tl.program_id(1)  # block column id
+
+    offs_x = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_y = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    x = tl.load(X + offs_x[:, None] * K + offs_y[None, :])
+    y = tl.load(Y + offs_x[:, None] * N + offs_y[None, :])
+
+    z = tl.dot(x, y)
+
+    tl.store(Z + offs_x[:, None] * N + offs_y[None, :], z)
+
+
+@benchmark.measure()
+def bench_matmul(N, provider):
+    device = 'cpu'
+    dtype = torch.float32
+    a = torch.randn((N, N), device=device, dtype=dtype)
+    b = torch.randn((N, N), device=device, dtype=dtype)
+    c = torch.empty((N, N), device=device, dtype=dtype)
+    if provider == 'torch' or provider == 'test':
+        c_ref = torch.matmul(a, b)
+    if provider == 'triton' or provider == 'test':
+        bare_matmul[(1,)](a, b, c, N, N, N, N)
+    if provider == 'test':
+        torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=0)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [2**i for i in range(7, 10, 1)]:
+        for provider in ['test', 'torch', 'triton']:
+            bench_matmul(X, provider)
diff --git a/python/examples/benchmark.py b/python/examples/benchmark.py
new file mode 100644
index 00000000..f82d3441
--- /dev/null
+++ b/python/examples/benchmark.py
@@ -0,0 +1,66 @@
+import time
+import numpy as np
+from functools import wraps
+import triton
+from triton.backends.triton_shared.driver import CPUDriver
+
+
+def select_cpu_backend():
+    triton.runtime.driver.set_active(CPUDriver())
+
+
+# Unfortunately, we can't use triton.testing.perf_report and triton.testing.do_bench for the CPU backend
+# because they are very specific to CUDA.
+
+def measure(repeats=20, percentiles=(), timers={'Wall': time.perf_counter, 'CPU': time.process_time}):
+    """
+    Decorator to benchmark a function.
+
+    Parameters:
+    - repeats (int): The number of times the function should be executed for each set of parameters.
+    - percentiles (tuple): The percentiles to compute on the execution times (e.g., (50, 90, 99)).
+    - timers (dict): A dictionary where keys are timer names (e.g., 'Wall', 'CPU') and values are timer functions
+                     that measure elapsed time. By default:
+        * 'Wall': Uses time.perf_counter for high-resolution wall-clock time.
+        * 'CPU': Uses time.process_time for CPU time spent by the process.
+
+    Returns:
+    - A decorated function that prints:
+        * Average execution time.
+        * Standard deviation of the execution times.
+        * Minimum and maximum times.
+        * Computed percentiles for each timer.
+    """
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            print(f"{func.__name__}{args} {kwargs}, {repeats} times, all results in seconds")
+            times = {}
+            for t, _ in timers.items():
+                times[t] = []
+
+            for _ in range(repeats):
+                starts = {}
+                for t, f in timers.items():
+                    starts[t] = f()
+
+                result = func(*args, **kwargs)
+
+                for t, f in timers.items():
+                    times[t].append(f() - starts[t])
+
+            for t, _ in timers.items():
+                average_time = np.mean(times[t])
+                min_time = np.min(times[t])
+                max_time = np.max(times[t])
+                computed_percentiles = np.percentile(times[t], percentiles)
+                std_dev_time = np.std(times[t])
+
+                print(f"{t}: Avg={average_time:.6f}, min={min_time:.6f}, std={std_dev_time:.6f},", end=" ")
+                for p, value in zip(percentiles, computed_percentiles):
+                    print(f"{p}pp={value:.6f},", end=" ")
+                print(f"max={max_time:.6f}")
+
+            return result
+        return wrapper
+    return decorator
\ No newline at end of file
diff --git a/python/examples/test_layernorm.py b/python/examples/test_layernorm.py
index 56134c0b..dc9b8471 100644
--- a/python/examples/test_layernorm.py
+++ b/python/examples/test_layernorm.py
@@ -23,6 +23,7 @@
 import triton
 import triton.language as tl
 import pytest
+import benchmark
 
 
 @triton.jit
@@ -133,3 +134,27 @@ def test_layer_norm(M, N, dtype, eps, device):
 
     # compare
     #assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
+
+
+@benchmark.measure()
+def bench_layernorm(size, provider):
+    layer_norm = LayerNorm.apply
+    device = 'cpu'
+    eps = 1e-5
+    dtype = torch.float16
+    x_shape = (size, size)
+    w_shape = (x_shape[-1], )
+    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=False)
+    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=False)
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
+    dy = .1 * torch.randn_like(x)
+    x.requires_grad_(False)
+    # forward pass
+    y_tri = layer_norm(x, w_shape, weight, bias, eps, device)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [2**i for i in range(10, 13, 1)]:
+        for provider in ['triton']:
+            bench_layernorm(X, provider)
\ No newline at end of file
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index 2281f7d4..470006a5 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -2,7 +2,7 @@
 
 import triton
 import triton.language as tl
-
+import benchmark
 
 # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
 #   - A list of `triton.Config` objects that define different configurations of
@@ -155,3 +155,20 @@ def test_matmul(device):
     triton_output = matmul(a, b)
     torch_output = torch.matmul(a, b)
     torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
+
+
+@benchmark.measure()
+def bench_matmul(M, N, K, provider):
+    a = torch.randn((M, K), device='cpu', dtype=torch.float32)
+    b = torch.randn((K, N), device='cpu', dtype=torch.float32)
+    if provider == 'torch':
+        torch.matmul(a, b)
+    if provider == 'triton':
+        matmul(a, b)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [128 * i for i in range(2, 7)]:
+        for provider in ['torch', 'triton']:
+            bench_matmul(X, X, X, provider)
diff --git a/python/examples/test_softmax.py b/python/examples/test_softmax.py
index 3508442c..b3d43dfa 100644
--- a/python/examples/test_softmax.py
+++ b/python/examples/test_softmax.py
@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+import benchmark
 
 
 @triton.jit
@@ -62,3 +63,20 @@ def test_softmax(device):
     y_triton = softmax(x)
     y_torch = torch.softmax(x, axis=1)
     assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
+
+
+@benchmark.measure()
+def bench_softmax(size, provider):
+    torch.manual_seed(0)
+    x = torch.randn(size, size, device='cpu')
+    if provider == 'torch':
+        torch.softmax(x, axis=1)
+    if provider == 'triton':
+        softmax(x)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [2**i for i in range(10, 14, 1)]:
+        for provider in ['torch', 'triton']:
+            bench_softmax(X, provider)
\ No newline at end of file
diff --git a/python/examples/test_vec_add.py b/python/examples/test_vec_add.py
index 4dc43139..db2fa098 100644
--- a/python/examples/test_vec_add.py
+++ b/python/examples/test_vec_add.py
@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+import benchmark
 
 
 @triton.jit
@@ -66,3 +67,19 @@ def test(device):
         f"The maximum difference between torch and triton is "
         f"{torch.max(torch.abs(output_torch - output_triton))}"
     )
+
+@benchmark.measure()
+def bench_vecadd(size, provider):
+    a = torch.rand(size, device='cpu', dtype=torch.float32)
+    b = torch.rand(size, device='cpu', dtype=torch.float32)
+    if provider == 'torch':
+        a + b
+    if provider == 'triton':
+        add(a, b)
+
+
+if __name__ == "__main__":
+    benchmark.select_cpu_backend()
+    for X in [2**i for i in range(22, 25, 1)]:
+        for provider in ['torch', 'triton']:
+            bench_vecadd(X, provider)
\ No newline at end of file
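
Note: unlike bare_matmul.py, the bench_* functions above only time the kernels
and do not cross-check the Triton results against torch. If verification is
also wanted, a 'test' provider could follow the same pattern as bare_matmul.py;
the sketch below is hypothetical and is not part of this patch:

```py
# Hypothetical sketch: a result-checking variant of the vector add benchmark,
# mirroring the 'test' provider used in bare_matmul.py (not part of this patch).
@benchmark.measure()
def bench_vecadd_checked(size, provider):
    a = torch.rand(size, device='cpu', dtype=torch.float32)
    b = torch.rand(size, device='cpu', dtype=torch.float32)
    if provider in ('torch', 'test'):
        expected = a + b          # reference result computed with torch
    if provider in ('triton', 'test'):
        actual = add(a, b)        # Triton kernel from test_vec_add.py
    if provider == 'test':
        torch.testing.assert_close(actual, expected)
```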