From 1c832842542ecf189b093b9e114b79a5c0c875c2 Mon Sep 17 00:00:00 2001
From: Aleksa Gordic
Date: Mon, 22 Jul 2024 17:51:14 +0200
Subject: [PATCH] Add conditional kv computation to encoder

---
 llmc/encoder.cuh   | 19 ++++++++++++-------
 llmc/layernorm.cuh |  4 ++--
 train_gpt2.cu      | 24 ++++++++++++++++--------
 3 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh
index 3aa63e175..f58e089fe 100644
--- a/llmc/encoder.cuh
+++ b/llmc/encoder.cuh
@@ -18,14 +18,14 @@ In the backward pass, the gradients flow to both, handled by different kernels

 __global__ void encoder_forward_kernel3(floatX* out,
                                const int* inp, const floatX* wte, const floatX* wpe,
-                               int B, int T, int C) {
+                               int B, int T, int C, int use_kv) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
-    int N = B * T * C;
+    int N = B * (use_kv ? 1 : T) * C;
     if (idx >= N) { return; }

     int bt = idx / C;
-    int b = bt / T;
-    int t = bt % T;
+    int b = bt / (use_kv ? 1 : T);
+    int t = use_kv ? 0 : bt % T;
     int c = idx % C;

     int ix = inp[b * T + t];
@@ -156,12 +156,17 @@ __global__ void wpe_backward_kernel(floatX* dwpe,

 void encoder_forward(floatX* out,
                      const int* inp, const floatX* wte, const floatX* wpe,
-                     int B, int T, int C, cudaStream_t stream) {
+                     int use_kv, int kv_offset, int B, int T, int C, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
-    const int N = B * T * C;
+    if (use_kv) {
+        inp += kv_offset;
+        wpe += kv_offset * C;
+        out += kv_offset * C;
+    }
+    const int N = B * (use_kv ? 1 : T) * C;
     const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
-    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C);
+    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C, use_kv);
     cudaCheck(cudaGetLastError());
 }

diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh
index 9777d0658..2186e9e8d 100644
--- a/llmc/layernorm.cuh
+++ b/llmc/layernorm.cuh
@@ -432,11 +432,11 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with
 // similar to `fused_residual_forward5`
 void layernorm_forward(floatX* out, float* mean, float* rstd, floatX* inp,
                        const floatX* weight, const floatX* bias,
-                       int B, int T, int C, cudaStream_t stream) {
+                       int use_kv, int kv_offset, int B, int T, int C, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
     int block_y = block_size / WARP_SIZE;
-    const int N = B * T;
+    const int N = B * (use_kv ? 1 : T);
     const int grid_size = CEIL_DIV(N, block_y);
     size_t smem = (2 + block_y) * C * sizeof(floatX);

diff --git a/train_gpt2.cu b/train_gpt2.cu
index 99b3c5eb8..dffc125bf 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -341,6 +341,8 @@ typedef struct {
     // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
     int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
     int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
+    int use_kv; // whether to use KV cache in attention
+    int kv_offset;
 } GPT2;

 void gpt2_init_common(GPT2 *model) {
@@ -369,6 +371,8 @@ void gpt2_init_common(GPT2 *model) {
     model->use_master_weights = 1; // safe default: do keep master weights in fp32
     model->recompute = 1; // good default: recompute gelu but not layernorm
     model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
+    model->use_kv = 0; // default: no KV cache
+    model->kv_offset = 0;
 }

 void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -647,10 +651,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     // forward pass
     ParameterTensors params = model->params; // for brevity
     ActivationTensors acts = model->acts;
-    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]
+    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, model->use_kv, model->kv_offset, B, T, C, main_stream); // encoding goes into residual[0]

     // first layernorm isn't fused
-    layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
+    layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, model->use_kv, model->kv_offset, B, T, C, main_stream);

     for (int l = 0; l < L; l++) {
         NvtxRange layer_range("Layer", l);
@@ -1397,8 +1401,8 @@ void error_usage() {
 // main training loop
 int main(int argc, char *argv[]) {
     // read in the (optional) command line arguments
-    const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
-    const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
+    const char* train_data_pattern = "/hdd/llmc/fineweb/bin/fineweb_train_*.bin";
+    const char* val_data_pattern = "/hdd/llmc/fineweb/bin/fineweb_val_*.bin";
     const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
     const char* lr_scheduler_type = "cosine";
     const char* output_log_dir = NULL;
@@ -1415,9 +1419,9 @@ int main(int argc, char *argv[]) {
     float weight_decay = 0.0f;
     float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore
     float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore
-    int val_loss_every = 20; // every how many steps do we eval validation loss?
+    int val_loss_every = 0; // every how many steps do we eval validation loss?
     int val_max_steps = 20; // how many batches max do we eval for validation loss?
-    int sample_every = 20; // every how many steps to do inference?
+    int sample_every = 1; // every how many steps to do inference?
     int genT = 64; // number of steps of inference we will do
     int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once
     int max_steps = -1;
@@ -1693,7 +1697,7 @@ int main(int argc, char *argv[]) {
         int last_step = step == train_num_batches;

         // once in a while estimate the validation loss (all processes collaborate)
-        if (step % val_loss_every == 0 || last_step) {
+        if (val_loss_every > 0 && (step % val_loss_every == 0 || last_step)) {
             NvtxRange validation_range("validation");
             float val_loss = 0.0f;
             dataloader_reset(&val_loader);
@@ -1727,8 +1731,9 @@ int main(int argc, char *argv[]) {
         }

         // once in a while do model inference to print generated text (only rank 0)
+        // TODO(gordicaleksa): temporarily set step >= 0 so that we always generate
         if (multi_gpu_config.process_rank == 0 && sample_every > 0 &&
-            (step > 0 && (step % sample_every) == 0 || last_step)) {
+            (step >= 0 && (step % sample_every) == 0 || last_step)) {
             NvtxRange generation_range("generation");
             unsigned long long sample_rng_state = 1337;
             // fill up gen_tokens with the <|endoftext|> token, which kicks off the generation
@@ -1738,6 +1743,7 @@ int main(int argc, char *argv[]) {
             }
             // now sample from the model autoregressively
             printf("generating:\n---\n");
+            model.use_kv = 1; // we need to use the KV cache for generation
             for (int t = 1; t < genT; t++) {
                 NvtxRange generation_range("Generation step", t);
                 // we try not to be too wasteful for inference by not calculating all of B,T
@@ -1747,6 +1753,7 @@ int main(int argc, char *argv[]) {
                 // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical
                 // (but even if it wasn't fully identical that's probably not the end of the world)
                 // note this is still somewhat wasteful because we don't have a KV cache!
+                model.kv_offset = t - 1;
                 gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));
                 // get the V-dimensional vector probs[0, t-1, :]
                 floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
@@ -1764,6 +1771,7 @@ int main(int argc, char *argv[]) {
                 safe_printf(token_str);
                 fflush(stdout);
             }
+            model.use_kv = 0; // don't use the KV cache outside of generation
             printf("\n---\n");
         }
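
For reference, below is a minimal CPU sketch (not part of the patch) of the indexing the use_kv path in encoder_forward implements: when use_kv is set, only the embedding row at kv_offset is written, and that row matches the same row of a full forward pass. It assumes B = 1 and plain float instead of floatX; the helper names encoder_forward_full and encoder_forward_kv are illustrative only, not functions from the repo.

// cpu sketch of the use_kv indexing; compile with: cc -o kv_sketch kv_sketch.c
#include <stdio.h>

// full pass: out[t, c] = wte[inp[t], c] + wpe[t, c] for every position t
static void encoder_forward_full(float* out, const int* inp, const float* wte,
                                 const float* wpe, int T, int C) {
    for (int t = 0; t < T; t++)
        for (int c = 0; c < C; c++)
            out[t * C + c] = wte[inp[t] * C + c] + wpe[t * C + c];
}

// use_kv pass: only the row at kv_offset is recomputed, mirroring the
// `inp += kv_offset; wpe += kv_offset * C; out += kv_offset * C;` in the patch
static void encoder_forward_kv(float* out, const int* inp, const float* wte,
                               const float* wpe, int kv_offset, int C) {
    for (int c = 0; c < C; c++)
        out[kv_offset * C + c] = wte[inp[kv_offset] * C + c] + wpe[kv_offset * C + c];
}

int main(void) {
    enum { T = 4, C = 2, V = 8 };           // toy sizes: seq len, channels, vocab
    int inp[T] = {1, 3, 5, 7};              // toy token ids, all < V
    float wte[V * C], wpe[T * C];
    float full[T * C] = {0}, kv[T * C] = {0};
    for (int i = 0; i < V * C; i++) wte[i] = 0.01f * i;
    for (int i = 0; i < T * C; i++) wpe[i] = 0.10f * i;

    encoder_forward_full(full, inp, wte, wpe, T, C);
    // simulate autoregressive generation: at each step only the newest position is computed
    for (int t = 0; t < T; t++) encoder_forward_kv(kv, inp, wte, wpe, t, C);

    for (int i = 0; i < T * C; i++)
        if (full[i] != kv[i]) { printf("mismatch at %d\n", i); return 1; }
    printf("use_kv rows match the full forward pass\n");
    return 0;
}

During generation the patch drives this with model.kv_offset = t - 1, so each gpt2_forward call only needs the embedding row for the newest token; rows for earlier positions in acts.encoded are left as they were from previous steps.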