How to run 02-fused-softmax.py on a CPU? #199

Closed
banach-space opened this issue Dec 6, 2024 · 2 comments
@banach-space

Hi folks,

First off, thank you for triton-shared—it's fantastic work!

I’ve successfully built the project and run the examples on my AArch64 machine. However, I’m running into an issue with 02-fused-softmax.py:

AssertionError: Torch not compiled with CUDA enabled

Clearly, something is trying to target CUDA, but I’m not sure what. I modified the example to explicitly select the CPU backend:

diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py
index 57f027277..8b10eb3d3 100644
--- a/python/tutorials/02-fused-softmax.py
+++ b/python/tutorials/02-fused-softmax.py
@@ -27,6 +27,8 @@ import triton
 import triton.language as tl
 from triton.runtime import driver

+from triton.backends.triton_shared.driver import CPUDriver
+

 def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -110,12 +112,9 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

-device = torch.cuda.current_device()
-properties = driver.active.utils.get_device_properties(device)
-NUM_SM = properties["multiprocessor_count"]
-NUM_REGS = properties["max_num_regs"]
-SIZE_SMEM = properties["max_shared_mem"]
-WARP_SIZE = properties["warpSize"]
+triton.runtime.driver.set_active(CPUDriver())
+device = 'cpu'
+# device = torch.cuda.current_device()
 target = triton.runtime.driver.active.get_current_target()
 kernels = {}

@@ -146,25 +145,7 @@ def softmax(x):
         kernel._init_handles()
         n_regs = kernel.n_regs
         size_smem = kernel.metadata.shared
-        if is_hip():
-            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
-            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
-            # ISA SECTION (3.6.4 for CDNA3)
-            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
-            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
-            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
-            # not required to be equal numbers of both types.
-            if is_cdna():
-                NUM_GPRS = NUM_REGS * 2
-
-            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
-            # When we divide this number with WARP_SIZE we get maximum number of waves that can
-            # execute on a CU (multi-processor)  in parallel.
-            MAX_NUM_THREADS = properties["max_threads_per_sm"]
-            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
-            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
-        else:
-            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
+        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
         occupancy = min(occupancy, SIZE_SMEM // size_smem)
         num_programs = NUM_SM * occupancy
         kernels[BLOCK_SIZE] = (kernel, num_programs)

Unfortunately, I’m still getting the same error, so I suspect my changes might not be sufficient. I’m very new to this and probably missing something obvious - please bear with me 😅

Any guidance would be greatly appreciated. Let me know if you need additional logs or details from my setup!

Thanks,
Andrzej

@parsifal-47
Contributor

Hi Andrzej,
The perf part is internally very specific to GPUs, which is why in my PR #163 I do not use it to benchmark Triton on CPU. I would recommend removing the "perf" part entirely and retrying.
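
Roughly, a stripped-down version without the perf/occupancy code could look like the sketch below. This is not the exact tutorial code: the kernel body, shapes, and block size are illustrative, and the CPUDriver import path is taken from the diff above.

```python
# Rough sketch only: kernel body and sizes are illustrative; the CPUDriver
# import path follows the diff in the original report.
import torch
import triton
import triton.language as tl
from triton.backends.triton_shared.driver import CPUDriver

triton.runtime.driver.set_active(CPUDriver())


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_cols, BLOCK_SIZE: tl.constexpr):
    # One program handles one row of the input.
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(input_ptr + row_idx * input_row_stride + col_offsets,
                  mask=mask, other=-float('inf'))
    row = row - tl.max(row, axis=0)  # numerically stable softmax
    num = tl.exp(row)
    out = num / tl.sum(num, axis=0)
    tl.store(output_ptr + row_idx * output_row_stride + col_offsets, out,
             mask=mask)


def softmax(x):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # One program per row; no SM/register/shared-memory occupancy tuning on CPU.
    softmax_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_cols,
                              BLOCK_SIZE=triton.next_power_of_2(n_cols))
    return y


x = torch.randn(128, 781, device='cpu')
torch.testing.assert_close(softmax(x), torch.softmax(x, dim=-1))
```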

Thank you!

@banach-space
Author

Hi Renat, thanks for getting back to me so quickly and for the link. Yes, removing the perf part fixes the issue :)

Let me close this and just follow what you did in #163. It would be nice to have it merged in :)

-Andrzej

nhat-nguyen pushed a commit that referenced this issue Jan 3, 2025
…#209)

Since we cannot use the standard Triton benchmarks, as brought up in #199, because they are specific to GPUs. Sample output:
```sh
$ python test_softmax.py
bench_softmax(1024, 'torch') {}, 20 times, all results in seconds
Wall: Avg=0.006537, min=0.005301, std=0.000326, max=0.006723
CPU: Avg=0.123649, min=0.010989, std=0.026653, max=0.140211
bench_softmax(1024, 'triton') {}, 20 times, all results in seconds
Wall: Avg=0.102619, min=0.014122, std=0.384826, max=1.780037
CPU: Avg=0.028643, min=0.014123, std=0.062372, max=0.300513
bench_softmax(2048, 'torch') {}, 20 times, all results in seconds
Wall: Avg=0.015215, min=0.013364, std=0.002282, max=0.022841
CPU: Avg=0.172217, min=0.043525, std=0.037402, max=0.231176
bench_softmax(2048, 'triton') {}, 20 times, all results in seconds
Wall: Avg=0.071460, min=0.055257, std=0.068684, max=0.370846
CPU: Avg=0.062689, min=0.055258, std=0.030449, max=0.195406
bench_softmax(4096, 'torch') {}, 20 times, all results in seconds
Wall: Avg=0.056267, min=0.056117, std=0.000134, max=0.056681
CPU: Avg=0.313888, min=0.220500, std=0.023960, max=0.338866
bench_softmax(4096, 'triton') {}, 20 times, all results in seconds
Wall: Avg=0.258867, min=0.244147, std=0.062352, max=0.530646
CPU: Avg=0.249397, min=0.244141, std=0.021087, max=0.341300
```
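
For context, a minimal wall-clock/CPU-time harness along the following lines could produce output in that format; the actual test_softmax.py added in #209 may be structured differently, and `bench` below is a hypothetical helper.

```python
# Hypothetical timing helper, for illustration only; not the real test_softmax.py.
import time
import statistics


def bench(fn, *args, repeat=20, **kwargs):
    wall, cpu = [], []
    for _ in range(repeat):
        w0, c0 = time.perf_counter(), time.process_time()
        fn(*args, **kwargs)
        wall.append(time.perf_counter() - w0)   # elapsed wall-clock time
        cpu.append(time.process_time() - c0)    # CPU time across all threads
    for name, xs in (("Wall", wall), ("CPU", cpu)):
        print(f"{name}: Avg={statistics.mean(xs):.6f}, min={min(xs):.6f}, "
              f"std={statistics.stdev(xs):.6f}, max={max(xs):.6f}")


# Example usage:
# bench(torch.softmax, torch.randn(1024, 1024), dim=-1)
```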

---------

Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>