-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding benchmarks for vector addition, matmul, softmax, and layernorm (…
…#209) Since we cannot use standard triton benchmarks as brought up here: #199 because they are specific to GPU. Sample output: ```sh $ python test_softmax.py bench_softmax(1024, 'torch') {}, 20 times, all results in seconds Wall: Avg=0.006537, min=0.005301, std=0.000326, max=0.006723 CPU: Avg=0.123649, min=0.010989, std=0.026653, max=0.140211 bench_softmax(1024, 'triton') {}, 20 times, all results in seconds Wall: Avg=0.102619, min=0.014122, std=0.384826, max=1.780037 CPU: Avg=0.028643, min=0.014123, std=0.062372, max=0.300513 bench_softmax(2048, 'torch') {}, 20 times, all results in seconds Wall: Avg=0.015215, min=0.013364, std=0.002282, max=0.022841 CPU: Avg=0.172217, min=0.043525, std=0.037402, max=0.231176 bench_softmax(2048, 'triton') {}, 20 times, all results in seconds Wall: Avg=0.071460, min=0.055257, std=0.068684, max=0.370846 CPU: Avg=0.062689, min=0.055258, std=0.030449, max=0.195406 bench_softmax(4096, 'torch') {}, 20 times, all results in seconds Wall: Avg=0.056267, min=0.056117, std=0.000134, max=0.056681 CPU: Avg=0.313888, min=0.220500, std=0.023960, max=0.338866 bench_softmax(4096, 'triton') {}, 20 times, all results in seconds Wall: Avg=0.258867, min=0.244147, std=0.062352, max=0.530646 CPU: Avg=0.249397, min=0.244141, std=0.021087, max=0.341300 ``` --------- Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>
- Loading branch information
1 parent
20fb38b
commit ebe3845
Showing
6 changed files
with
189 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# this is a benchmark which multiplies square matrices with maximum block size | ||
# to check the performance of tl.dot operation | ||
|
||
import torch | ||
import triton | ||
import triton.language as tl | ||
import benchmark | ||
|
||
|
||
@triton.jit | ||
def bare_matmul(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr): | ||
pid_x = tl.program_id(0) # block row id | ||
pid_y = tl.program_id(1) # block column id | ||
|
||
offs_x = pid_x * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
offs_y = pid_y * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
|
||
x = tl.load(X + offs_x[:, None] * K + offs_y[None, :]) | ||
y = tl.load(Y + offs_x[:, None] * N + offs_y[None, :]) | ||
|
||
z = tl.dot(x, y) | ||
|
||
tl.store(Z + offs_x[:, None] * N + offs_y[None, :], z) | ||
|
||
|
||
@benchmark.measure() | ||
def bench_matmul(N, provider): | ||
device = 'cpu' | ||
dtype = torch.float32 | ||
a = torch.randn((N, N), device=device, dtype=dtype) | ||
b = torch.randn((N, N), device=device, dtype=dtype) | ||
c = torch.empty((N, N), device=device, dtype=dtype) | ||
if provider == 'torch' or provider == 'test': | ||
c_ref = torch.matmul(a, b) | ||
if provider == 'triton' or provider == 'test': | ||
bare_matmul[(1,)](a, b, c, N, N, N, N) | ||
if provider == 'test': | ||
torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=0) | ||
|
||
|
||
if __name__ == "__main__": | ||
benchmark.select_cpu_backend() | ||
for X in [2**i for i in range(7, 10, 1)]: | ||
for provider in ['test', 'torch', 'triton']: | ||
bench_matmul(X, provider) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import time | ||
import numpy as np | ||
from functools import wraps | ||
import triton | ||
from triton.backends.triton_shared.driver import CPUDriver | ||
|
||
|
||
def select_cpu_backend(): | ||
triton.runtime.driver.set_active(CPUDriver()) | ||
|
||
|
||
# Unfortunately, we can't use triton.testing.perf_report and triton.testing.do_bench for CPU backend because | ||
# they are very specific to cuda | ||
|
||
def measure(repeats=20, percentiles=(), timers={'Wall':time.perf_counter, 'CPU':time.process_time}): | ||
""" | ||
Decorator to benchmark a function. | ||
Parameters: | ||
- repeats (int): The number of times the function should be executed for each set of parameters. | ||
- percentiles (tuple): The percentiles to compute on the execution times (e.g., (50, 90, 99)). | ||
- timers (dict): A dictionary where keys are timer names (e.g., 'Wall', 'CPU') and values are timer functions | ||
that measure elapsed time. By default: | ||
* 'Wall': Uses time.perf_counter for high-resolution wall-clock time. | ||
* 'CPU': Uses time.process_time for CPU time spent by the process. | ||
Returns: | ||
- A decorated function that prints: | ||
* Average execution time. | ||
* Standard deviation time. | ||
* Minimum and maximum times. | ||
* Computed percentiles for each timer. | ||
""" | ||
def decorator(func): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
print(f"{func.__name__}{args} {kwargs}, {repeats} times, all results in seconds") | ||
times = {} | ||
for t, _ in timers.items(): | ||
times[t] = [] | ||
|
||
for _ in range(repeats): | ||
starts = {} | ||
for t, f in timers.items(): | ||
starts[t] = f() | ||
|
||
result = func(*args, **kwargs) | ||
|
||
for t, f in timers.items(): | ||
times[t].append(f() - starts[t]) | ||
|
||
for t, _ in timers.items(): | ||
average_time = np.mean(times[t]) | ||
min_time = np.min(times[t]) | ||
max_time = np.max(times[t]) | ||
computed_percentiles = np.percentile(times[t], percentiles) | ||
std_dev_time = np.std(times[t]) | ||
|
||
print(f"{t}: Avg={average_time:.6f}, min={min_time:.6f}, std={std_dev_time:.6f},", end=" ") | ||
for p, value in zip(percentiles, computed_percentiles): | ||
print(f"{p}pp={value:.6f},", end=" ") | ||
print(f"max={max_time:.6f}") | ||
|
||
return result | ||
return wrapper | ||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters