rename to _forward for the kernel, and small fixes to docs, and avoid auto
karpathy committed Apr 22, 2024
1 parent 7830cf6 commit b1e5595
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions dev/cuda/trimat.cu → dev/cuda/trimat_forward.cu
@@ -1,5 +1,27 @@
// Triangular matrix multiplication as in autoregressive attention. A short story.
//
/*
Triangular matrix multiplication as in autoregressive attention. A short story.
by @ngc92

Compile:
nvcc -O3 --use_fast_math trimat_forward.cu -o trimat_forward -lcublas

Run:

cuBLAS baseline kernel
./trimat_forward 0

naive
./trimat_forward 1

registers
./trimat_forward 2

tri3
./trimat_forward 3

tri4
./trimat_forward 4
*/
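For orientation before the chapters below: every kernel in this file computes the same thing, the pre-softmax attention scores. Per head, that is the lower triangle of alpha * Q K^T, where query position t only attends to key positions 0..t. A minimal CPU sketch with simplified contiguous strides (trimul_cpu_ref is a hypothetical reference, not code from this commit; the real kernels read strided views out of a fused QKV activation instead):

// Lower-triangular Q K^T: the operation all five kernel variants implement.
// Simplified layout: q and k are (T, hs) and densely packed.
void trimul_cpu_ref(float* preatt, const float* q, const float* k,
                    int T, int hs, float alpha) {
    for (int t = 0; t < T; t++) {          // query position
        for (int s = 0; s <= t; s++) {     // key position, s <= t only
            float dot = 0.0f;
            for (int i = 0; i < hs; i++) {
                dot += q[t * hs + i] * k[s * hs + i];
            }
            preatt[t * T + s] = dot * alpha;  // alpha is the attention scale, typically 1/sqrt(hs)
        }
    }
}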

#include <stdio.h>
#include <stdlib.h>
@@ -14,7 +36,6 @@
static cublasHandle_t cublas_handle;
static float* d_qkvr; // scratch for the cublas kernel


/* ** Chapter I - Introduction **
*
* You are Trimul. You've always wanted to do fast matrix multiplication, but they said
@@ -152,7 +173,10 @@ void trimul_cublas(float* preatt,
* Let's observe how we're doing.
*/

template<auto matmul_tri>
// using creates an alias for a function pointer
using matmul_fn_ptr = void(*)(float* p, int ps, const float* k, int ks, const float* q, int qs, int T, int hs, float alpha);

template<matmul_fn_ptr matmul_tri>
__global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float* inp, int T, int C, int NH) {
// skip above the diagonal
if(blockIdx.y > blockIdx.x)
@@ -176,9 +200,8 @@ __global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float*
matmul_tri(r, T, q, C3, k, C3, T, hs, scale);
}
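This is evidently the "avoid auto" from the commit title: a named function-pointer alias pins down the exact signature every kernel variant must match, whereas template<auto> accepts any non-type argument and only errors at the point of use. A toy sketch of the same pattern (names here are illustrative, not from this file):

// A named alias spells out the required signature once...
using binary_op = float(*)(float, float);

__device__ float add_op(float a, float b) { return a + b; }

// ...so a variant with the wrong signature fails at instantiation,
// not deep inside the kernel body.
template<binary_op op>
__global__ void apply(float* out, const float* x, const float* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = op(x[i], y[i]);
}
// usage: apply<add_op><<<grid, block>>>(out, x, y, n);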

template<auto matmul_tri>
void trimul_launcher(float* out, const float* inp,
int B, int T, int C, int NH) {
template<matmul_fn_ptr matmul_tri>
void trimul_launcher(float* out, const float* inp, int B, int T, int C, int NH) {
// we assume nice shapes here. Let's not make the code a mess by supporting weird shapes that you
// wouldn't want to use anyway.
assert(T % 128 == 0);
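The launcher body is collapsed in this view, but the assert together with the blockIdx.y > blockIdx.x early-out in trimul_global implies a square grid of 128x128 output tiles over the T x T score matrix, with the upper-triangular half of the blocks exiting immediately. A plausible launch, stated as an assumption rather than the collapsed code (the z-dimension handling of batch and heads in particular is guessed):

dim3 block_dim(256);                      // matches __launch_bounds__(256, 2) above
dim3 grid_dim(T / 128, T / 128, B * NH);  // one 128x128 tile per (x, y); batch*head on z
trimul_global<matmul_tri><<<grid_dim, block_dim>>>(out, inp, T, C, NH);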
@@ -240,7 +263,6 @@ __device__ void matmul_tri_naive(float* p, int ps, const float* k, int ks, const
}
}


/* ** Chapter IV - ... **
*
* Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send their runners of 64 times
@@ -387,8 +409,6 @@ __device__ void matmul_tri3(float* p, int ps, const float* k, int ks, const floa
}
}



/* ** Chapter V - Sharing is Caring **
*
* You take a look around the shed, and see that there are 32 shelves there. They are much larger than the workbenches,
Expand Down Expand Up @@ -470,19 +490,18 @@ __device__ void matmul_tri4(float* p, int ps, const float* k, int ks, const floa
}
}
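Chapter V's shelves are shared memory: rather than each thread re-reading its operands from global memory, the block stages a tile once and every thread then reads it on-chip. A generic sketch of that staging idiom (a toy kernel, not the collapsed matmul_tri4 body):

__global__ void stage_and_reuse(float* out, const float* in, int n) {
    __shared__ float tile[256];                  // one slot per thread in the block
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    tile[threadIdx.x] = (i < n) ? in[i] : 0.0f;  // cooperative load, one global read each
    __syncthreads();                             // wait until the shelf is fully stocked
    // now any staged value is cheap to reuse, e.g. a neighbor's:
    float left = (threadIdx.x > 0) ? tile[threadIdx.x - 1] : 0.0f;
    if (i < n) out[i] = tile[threadIdx.x] + left;
}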


/* ** Chapter VI - Competition Day **
*
* Finally, you feel ready to take on Cublas. You hand out tickets to the event for your friends to see.
*
* -----------------------------------------------------------------
* | CuBLAS vs TriMul - Fight of the Century |
* | |
* | Ticket code: |
* | > nvcc -O3 --use_fast_math trimat.cu -o trimat -lcublas |
* | > ./trimat 4 |
* | |
* -----------------------------------------------------------------
*
* ---------------------------------------------------------------------------------
* | CuBLAS vs TriMul - Fight of the Century |
* | |
* | Ticket code: |
* | > nvcc -O3 --use_fast_math trimat_forward.cu -o trimat_forward -lcublas |
* | > ./trimat_forward 4 |
* | |
* ---------------------------------------------------------------------------------
*/

void trimul_gpu(int kernel_num,
