Merge pull request #712 from indianspeedster/dev-cuda
Added permute kernel in dev/cuda
karpathy authored Jul 28, 2024
2 parents 7379cd1 + a80bce6 commit b4623bc
Showing 2 changed files with 202 additions and 1 deletion.
4 changes: 3 additions & 1 deletion dev/cuda/Makefile
@@ -30,7 +30,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-
	$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@

# Build all targets
-TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm
+TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute
all: $(TARGETS)
all_ptx: $(TARGETS:%=%.ptx)
all_sass: $(TARGETS:%=%.sass)
@@ -64,6 +64,8 @@ matmul_backward: matmul_backward.cu
adamw: adamw.cu
global_norm: global_norm.cu

+permute: permute.cu

# NCCL communication kernels
nccl_all_reduce: nccl_all_reduce.cu
$(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce
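
With this target added, the kernel builds from dev/cuda like the other kernels (a usage sketch, assuming the existing pattern rule above):

    make permute
    ./permute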
199 changes: 199 additions & 0 deletions dev/cuda/permute.cu
@@ -0,0 +1,199 @@
/*
Kernels to demonstrate permute operation.
Compile example:
nvcc -O3 permute.cu -o permute
The goal is to permute a 4D matrix from its original shape (dim1, dim2, dim3, dim4) to a new shape (dim4, dim3, dim1, dim2).
Before permutation, we need to understand how to access elements in a flattened (linear) form of the matrix.
Given:
dim1 = size of the 1st dimension
dim2 = size of the 2nd dimension
dim3 = size of the 3rd dimension
dim4 = size of the 4th dimension
For any element in a 4D matrix at position (i1, i2, i3, i4), where:
i1 is the index in dimension 1
i2 is the index in dimension 2
i3 is the index in dimension 3
i4 is the index in dimension 4
If you find it challenging to calculate the indices i1, i2, i3, and i4, observe the pattern in the index calculations.
Initially, it might take some time to grasp, but with practice, you'll develop a mental model for it.
To calculate the indices, use the following formulas:
i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
i2 = (idx / (dim3 * dim4)) % dim2;
i3 = (idx / dim4) % dim3;
i4 = idx % dim4;
Pattern Explanation:
To find the index for any dimension, divide the thread ID (idx) by the product of all subsequent dimensions.
Then, take the result modulo the size of the current dimension.
The linear index in a flattened 1D array is calculated as:
linear_idx = i1 * (dim2 * dim3 * dim4) + i2 * (dim3 * dim4) + i3 * dim4 + i4
This linear index uniquely identifies the position of the element in the 1D array.
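As a quick sanity check (an illustrative example, not part of the original notes): take (dim1, dim2, dim3, dim4) = (2, 3, 4, 5) and idx = 57. Then:
i1 = (57 / 60) % 2 = 0
i2 = (57 / 20) % 3 = 2
i3 = (57 / 5) % 4 = 3
i4 = 57 % 5 = 2
Recombining: 0 * 60 + 2 * 20 + 3 * 5 + 2 = 57, which recovers the original linear index.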
To permute the matrix, we need to rearrange the indices according to the new shape.
In this case, we are permuting from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).
The new dimension order after the permutation is as follows:
dim1 becomes the new 3rd dimension.
dim2 becomes the new 4th dimension.
dim3 becomes the new 2nd dimension.
dim4 becomes the new 1st dimension.
permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
Here's how this works:
i4 * (dim3 * dim1 * dim2): This accounts for how many complete dim3 × dim1 × dim2 blocks fit before the current i4 block.
i3 * (dim1 * dim2): This accounts for the offset within the current i4 block, specifying which i3 block we are in.
i1 * dim2: This accounts for the offset within the current i3 block, specifying which i1 block we are in.
i2: This gives the offset within the current i1 block.
Finally, we store the value at index idx of the original matrix at index permuted_idx in the permuted matrix.
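Continuing the illustrative example above: the element at (i1, i2, i3, i4) = (0, 2, 3, 2) lands at
permuted_idx = 2 * (4 * 2 * 3) + 3 * (2 * 3) + 0 * 3 + 2 = 48 + 18 + 0 + 2 = 68,
i.e. position (2, 3, 0, 2) in the permuted shape (dim4, dim3, dim1, dim2) = (5, 4, 2, 3).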
--------------------------------------------------------------------------------------------------------------------------------------------------------
The same approach can be followed to permute matrices with any number of dimensions.
*/


#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include <cmath>

// CPU function to permute a 4D matrix
void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int total_threads = dim1 * dim2 * dim3 * dim4;

    for (int idx = 0; idx < total_threads; idx++) {
        // Calculate the 4D indices from the linear index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;

        // Compute the new index for the permuted matrix
        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}

// CUDA kernel to permute a 4D matrix
__global__ void permute_cuda(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Ensure index is within bounds
    if (idx < dim1 * dim2 * dim3 * dim4) {
        // Calculate the 4D indices from the linear index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;

        // Compute the new index for the permuted matrix
        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}
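
// Not part of the original commit: a grid-stride sketch of the same permute,
// which also works when the launch grid has fewer threads than elements.
// Each thread strides over the flattened tensor by the total grid width.
__global__ void permute_cuda_gridstride(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int total = dim1 * dim2 * dim3 * dim4;
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) {
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;
        out_matrix[i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2] = matrix[idx];
    }
}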

// Function to check if the CUDA permutation result matches the CPU result
bool verify_results(const float* permuted_matrix_cuda, const float* permuted_matrix_cpu, int totalElements) {
    bool success = true;
    for (int i = 0; i < totalElements; i++) {
        // Allow a small tolerance for floating-point comparison
        if (fabs(permuted_matrix_cuda[i] - permuted_matrix_cpu[i]) > 1e-5) {
            success = false;
            printf("Permute Operation Failed\n");
            printf("CPU: %f\n", permuted_matrix_cpu[i]);
            printf("CUDA: %f\n", permuted_matrix_cuda[i]);
            break; // Exit early on the first failure
        }
    }
    if (success) {
        printf("Permute Operation Passed\n");
    }
    return success;
}

// Function to initialize the matrix with random values
void initialize_matrix(float* mat, int dim_1, int dim_2, int dim_3, int dim_4) {
    for (int i = 0; i < dim_1 * dim_2 * dim_3 * dim_4; ++i) {
        mat[i] = static_cast<float>(rand()) / RAND_MAX;
    }
    printf("Matrix Initialized\n");
}

int main() {
    int dim_1 = 24;
    int dim_2 = 42;
    int dim_3 = 20;
    int dim_4 = 32;

    // Set up the device
    int deviceIdx = 0;
    cudaSetDevice(deviceIdx);
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, deviceIdx);
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // Allocate host memory
    float* matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
    float* permuted_matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
    float* permuted_matrix_cpu = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

    // Initialize the matrix with random values
    initialize_matrix(matrix, dim_1, dim_2, dim_3, dim_4);

    // Allocate device memory
    float *d_matrix, *d_permuted_matrix;
    cudaMalloc(&d_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
    cudaMalloc(&d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

    // Copy matrix from host to device
    cudaMemcpy(d_matrix, matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyHostToDevice);

    // Perform permutation on CPU
    permute_cpu(matrix, permuted_matrix_cpu, dim_1, dim_2, dim_3, dim_4);

    // Define block and grid sizes
    dim3 blockSize(256);
    int totalThreads = dim_1 * dim_2 * dim_3 * dim_4;
    int gridSize = (totalThreads + blockSize.x - 1) / blockSize.x; // Compute grid size
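    // Worked arithmetic (added for clarity): with the dims above,
    // totalThreads = 24 * 42 * 20 * 32 = 645120, which divides 256 exactly,
    // so gridSize = 645120 / 256 = 2520 blocks.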

    // Launch CUDA kernel to perform permutation
    permute_cuda<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
    cudaDeviceSynchronize(); // Ensure kernel execution is complete

    // Copy the result from device to host
    cudaMemcpy(permuted_matrix, d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyDeviceToHost);

    // Verify results
    verify_results(permuted_matrix, permuted_matrix_cpu, dim_1 * dim_2 * dim_3 * dim_4);

    // Free allocated memory
    free(matrix);
    free(permuted_matrix);
    free(permuted_matrix_cpu);
    cudaFree(d_matrix);
    cudaFree(d_permuted_matrix);

    return 0;
}
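
On a successful run, the program prints the device name, "Matrix Initialized", and, if every element of the CUDA result matches the CPU reference within 1e-5, "Permute Operation Passed"; these messages come from the printf calls in the listing above.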
