Merge pull request #671 from ademeure/less_slow_inference
Faster inference by changing (B,T) to (1,t)
karpathy authored Jul 11, 2024
2 parents db2454f + 2fd6fee commit bdb0fb5
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions train_gpt2.cu
@@ -317,6 +317,12 @@ void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]
 
     void* acts_memory;
     cudaCheck(cudaMalloc((void**)&acts_memory, bytes));
+
+    // cudaMalloc does not guarantee initial memory values so we memset the allocation here
+    // this matters because e.g. non-cuDNN attention assumes the attention buffer is zeroed
+    // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer
+    cudaCheck(cudaMemset(acts_memory, 0, bytes));
+
     char* acts_memory_iterator = (char*)acts_memory;
     for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
         // extra protection so we don't accidentally use an empty buffer
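A minimal standalone sketch of the allocate-then-clear pattern introduced above (not part of the commit; the buffer size and bare error handling are illustrative). cudaMalloc hands back uninitialized device memory, so code that only reads part of a buffer, like the non-cuDNN attention path with a smaller T, must not rely on it starting at zero.

#include <cuda_runtime.h>
#include <stdio.h>

int main(void) {
    size_t bytes = 16 * 1024 * 1024;   // illustrative allocation size
    void* acts_memory = NULL;
    // cudaMalloc does not initialize the allocation; its contents are indeterminate
    if (cudaMalloc(&acts_memory, bytes) != cudaSuccess) { fprintf(stderr, "cudaMalloc failed\n"); return 1; }
    // clear it once up-front so kernels that assume zeroed scratch read 0 instead of stale garbage
    if (cudaMemset(acts_memory, 0, bytes) != cudaSuccess) { fprintf(stderr, "cudaMemset failed\n"); return 1; }
    cudaFree(acts_memory);
    return 0;
}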
@@ -610,9 +616,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
         cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));
     } else {
-        // validate B,T is consistent with how we've allocated the memory before
-        // in principle we could get more clever here in the future, for now this is safest
-        if (B != model->batch_size || T != model->seq_len) {
+        // validate B,T are not larger than the values used at initialisation
+        // (smaller B,T are okay for inference only)
+        if (B > model->batch_size || T > model->seq_len) {
             printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
             exit(EXIT_FAILURE);
         }
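The relaxed check treats the shapes passed at initialisation as an upper bound: the activation and loss buffers were sized for (batch_size, seq_len), so any smaller (B, T) still fits inside the same allocation, while a larger one would overrun it. A tiny sketch of that rule (the helper and the concrete shapes below are assumptions, not code from the repository):

#include <stdio.h>

// hypothetical helper capturing the new rule: allocation-time shapes are an upper bound
static int fits_allocation(int B, int T, int alloc_B, int alloc_T) {
    return B <= alloc_B && T <= alloc_T;   // smaller or equal fits; larger would overrun the buffers
}

int main(void) {
    int alloc_B = 4, alloc_T = 1024;                                          // assumed training-time shapes
    printf("(1, 256)  -> %d\n", fits_allocation(1, 256, alloc_B, alloc_T));   // 1: inference-sized pass
    printf("(4, 1024) -> %d\n", fits_allocation(4, 1024, alloc_B, alloc_T));  // 1: the original shape
    printf("(4, 2048) -> %d\n", fits_allocation(4, 2048, alloc_B, alloc_T));  // 0: rejected, T too large
    return 0;
}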
@@ -671,6 +677,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream);
         #else
         floatX* l_att = acts.att + l * B * NH * T * T;
+        if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent)
+            cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX)));
+        }
         // these are only needed as scratchpads for the forward pass, but
         // need not be stored for backward
         matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
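A rough back-of-the-envelope for the extra memset on the non-cuDNN path (the head count, floatX width, and rounded T below are assumptions for illustration): the per-layer attention scratch holds B * NH * T * T elements, which is small for the (1, t) inference shapes, so clearing it each step costs little next to the matmuls.

#include <stdio.h>

int main(void) {
    size_t B = 1, NH = 12, T = 256;            // one inference stream, GPT-2 124M head count (assumed)
    size_t bytes = B * NH * T * T * 2;         // floatX assumed to be 2 bytes (BF16/FP16)
    printf("per-layer attention scratch: %zu bytes (%.1f MiB)\n",
           bytes, bytes / (1024.0 * 1024.0));  // 1572864 bytes, 1.5 MiB
    return 0;
}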
@@ -1717,14 +1726,14 @@ int main(int argc, char *argv[]) {
             printf("generating:\n---\n");
             for (int t = 1; t < genT; t++) {
                 NvtxRange generation_range("Generation step", t);
-                // note that inference is very wasteful here because for each token
-                // we re-calculate the forward pass for all of (B,T) positions from scratch
-                // but the inference here is just for sanity checking anyway
-                // and we can maybe optimize a bit more later, with careful tests
-                gpt2_forward(&model, gen_tokens, B, T);
-                // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
-                // we're in principle running B "inference streams" in parallel here
-                // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)
+                // we try not to be too wasteful for inference by not calculating all of B,T
+                // Using a smaller B is always bit-for-bit identical, but T is more tricky
+                // for non-CUDNN, we need to make sure the attention buffer is memset to 0
+                // for cuDNN, it might suddenly decide to use a slightly different algorithm...
+                // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical
+                // (but even if it wasn't fully identical that's probably not the end of the world)
+                // note this is still somewhat wasteful because we don't have a KV cache!
+                gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));
                 // get the V-dimensional vector probs[0, t-1, :]
                 floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
                 // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)
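The new call rounds the current position t up to a multiple of min(T, 256) and runs the forward pass over a single row, so early generation steps process (1, 256) tokens instead of (B, T), and the sequence length only grows in 256-token buckets. A small sketch of that bucketing (T = 1024 and the sample t values are assumptions):

#include <stdio.h>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main(void) {
    int T = 1024;                                  // assumed max sequence length
    int step = T < 256 ? T : 256;                  // min(T, 256)
    int samples[] = {1, 64, 256, 257, 700};
    for (int i = 0; i < 5; i++) {
        int t = samples[i];
        int t_fwd = CEIL_DIV(t, step) * step;      // round t up to the next 256-token bucket
        printf("t = %3d -> gpt2_forward over (1, %4d) instead of (B, %d)\n", t, t_fwd, T);
    }
    return 0;
}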
