Add conditional kv computation to encoder
gordicaleksa committed Jul 22, 2024
1 parent 0585227 commit 1c83284
Showing 3 changed files with 30 additions and 17 deletions.
19 changes: 12 additions & 7 deletions llmc/encoder.cuh
@@ -18,14 +18,14 @@ In the backward pass, the gradients flow to both, handled by different kernels
 
 __global__ void encoder_forward_kernel3(floatX* out,
                                         const int* inp, const floatX* wte, const floatX* wpe,
-                                        int B, int T, int C) {
+                                        int B, int T, int C, int use_kv) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
-    int N = B * T * C;
+    int N = B * (use_kv ? 1 : T) * C;
     if (idx >= N) { return; }
 
     int bt = idx / C;
-    int b = bt / T;
-    int t = bt % T;
+    int b = bt / (use_kv ? 1 : T);
+    int t = use_kv ? 0 : bt % T;
     int c = idx % C;
 
     int ix = inp[b * T + t];
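Reading the new index math: when use_kv is set the grid covers only B * 1 * C elements, the divisor for b collapses to 1 (so b == bt) and t is pinned to 0, meaning every thread works on position 0 of the buffers it is handed. As a quick sanity check with toy sizes that are not from the commit (B = 2, C = 768): N = 1536, and element index idx = 904 maps to bt = 1, b = 1, t = 0, c = 136, so that thread reads its token id from inp[1 * T + 0]; the host wrapper in the next hunk shifts inp by kv_offset so that this "position 0" is really the newest token.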
@@ -156,12 +156,17 @@ __global__ void wpe_backward_kernel(floatX* dwpe,
 
 void encoder_forward(floatX* out,
                      const int* inp, const floatX* wte, const floatX* wpe,
-                     int B, int T, int C, cudaStream_t stream) {
+                     int use_kv, int kv_offset, int B, int T, int C, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
-    const int N = B * T * C;
+    if (use_kv) {
+        inp += kv_offset;
+        wpe += kv_offset * C;
+        out += kv_offset * C;
+    }
+    const int N = B * (use_kv ? 1 : T) * C;
     const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
-    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C);
+    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C, use_kv);
     cudaCheck(cudaGetLastError());
 }
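A note on the pointer bumps (this assumes the kernel computes out[(b*T + t)*C + c] = wte[ix*C + c] + wpe[t*C + c] with ix = inp[b*T + t], as in the unmodified kernel; only the inp read is visible in the hunk above): with use_kv set and, say, kv_offset = 5, t is pinned to 0, so after the three bumps the kernel effectively writes out[(b*T + 5)*C + c] = wte[inp[b*T + 5]*C + c] + wpe[5*C + c] for every b and c, i.e. only the embedding of position kv_offset is recomputed.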

4 changes: 2 additions & 2 deletions llmc/layernorm.cuh
@@ -432,11 +432,11 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with
 
 // similar to `fused_residual_forward5`
 void layernorm_forward(floatX* out, float* mean, float* rstd,
                        floatX* inp, const floatX* weight, const floatX* bias,
-                       int B, int T, int C, cudaStream_t stream) {
+                       int use_kv, int kv_offset, int B, int T, int C, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
     int block_y = block_size / WARP_SIZE;
-    const int N = B * T;
+    const int N = B * (use_kv ? 1 : T);
     const int grid_size = CEIL_DIV(N, block_y);
     size_t smem = (2 + block_y) * C * sizeof(floatX);
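With use_kv set the row count here drops from B * T to B; for example with B = 1, T = 256, block_size = 256 and WARP_SIZE = 32, block_y = 8 and grid_size shrinks from CEIL_DIV(256, 8) = 32 blocks to CEIL_DIV(1, 8) = 1. Note that kv_offset is threaded through for symmetry with encoder_forward but is not used anywhere else in layernorm.cuh in this commit (the file's only two additions are the two lines shown above).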

24 changes: 16 additions & 8 deletions train_gpt2.cu
@@ -341,6 +341,8 @@ typedef struct {
     // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
     int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
     int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
+    int use_kv; // whether to use KV cache in attention
+    int kv_offset;
 } GPT2;
 
 void gpt2_init_common(GPT2 *model) {
@@ -369,6 +371,8 @@ void gpt2_init_common(GPT2 *model) {
     model->use_master_weights = 1; // safe default: do keep master weights in fp32
     model->recompute = 1; // good default: recompute gelu but not layernorm
     model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
+    model->use_kv = 0; // default: no KV cache
+    model->kv_offset = 0;
 }
 
 void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -647,10 +651,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     // forward pass
     ParameterTensors params = model->params; // for brevity
     ActivationTensors acts = model->acts;
-    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]
+    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, model->use_kv, model->kv_offset, B, T, C, main_stream); // encoding goes into residual[0]
 
     // first layernorm isn't fused
-    layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
+    layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, model->use_kv, model->kv_offset, B, T, C, main_stream);
 
     for (int l = 0; l < L; l++) {
         NvtxRange layer_range("Layer", l);
@@ -1397,8 +1401,8 @@ void error_usage() {
 // main training loop
 int main(int argc, char *argv[]) {
     // read in the (optional) command line arguments
-    const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
-    const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
+    const char* train_data_pattern = "/hdd/llmc/fineweb/bin/fineweb_train_*.bin";
+    const char* val_data_pattern = "/hdd/llmc/fineweb/bin/fineweb_val_*.bin";
     const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
     const char* lr_scheduler_type = "cosine";
     const char* output_log_dir = NULL;
@@ -1415,9 +1419,9 @@ int main(int argc, char *argv[]) {
     float weight_decay = 0.0f;
     float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore
     float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore
-    int val_loss_every = 20; // every how many steps do we eval validation loss?
+    int val_loss_every = 0; // every how many steps do we eval validation loss?
     int val_max_steps = 20; // how many batches max do we eval for validation loss?
-    int sample_every = 20; // every how many steps to do inference?
+    int sample_every = 1; // every how many steps to do inference?
     int genT = 64; // number of steps of inference we will do
     int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once
     int max_steps = -1;
@@ -1693,7 +1697,7 @@ int main(int argc, char *argv[]) {
         int last_step = step == train_num_batches;
 
         // once in a while estimate the validation loss (all processes collaborate)
-        if (step % val_loss_every == 0 || last_step) {
+        if (val_loss_every > 0 && (step % val_loss_every == 0 || last_step)) {
            NvtxRange validation_range("validation");
            float val_loss = 0.0f;
            dataloader_reset(&val_loader);
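Note on this guard: since val_loss_every now defaults to 0 (see the -1415 hunk above), the old condition would evaluate step % 0 under the default settings, which is undefined behaviour in C, so the added val_loss_every > 0 short-circuit both turns validation off by default and avoids the modulo by zero.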
@@ -1727,8 +1731,9 @@
         }
 
         // once in a while do model inference to print generated text (only rank 0)
+        // TODO(gordicaleksa): tmp set step to >= 0 to always generate
         if (multi_gpu_config.process_rank == 0 && sample_every > 0 &&
-            (step > 0 && (step % sample_every) == 0 || last_step)) {
+            (step >= 0 && (step % sample_every) == 0 || last_step)) {
             NvtxRange generation_range("generation");
             unsigned long long sample_rng_state = 1337;
             // fill up gen_tokens with the <|endoftext|> token, which kicks off the generation
@@ -1738,6 +1743,7 @@
             }
             // now sample from the model autoregressively
             printf("generating:\n---\n");
+            model.use_kv = 1; // we need to use the KV cache for generation
             for (int t = 1; t < genT; t++) {
                 NvtxRange generation_range("Generation step", t);
                 // we try not to be too wasteful for inference by not calculating all of B,T
@@ -1747,6 +1753,7 @@
                 // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical
                 // (but even if it wasn't fully identical that's probably not the end of the world)
                 // note this is still somewhat wasteful because we don't have a KV cache!
+                model.kv_offset = t - 1;
                 gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));
                 // get the V-dimensional vector probs[0, t-1, :]
                 floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
@@ -1764,6 +1771,7 @@
                 safe_printf(token_str);
                 fflush(stdout);
             }
+            model.use_kv = 0; // don't use the KV cache outside of generation
             printf("\n---\n");
         }
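Taken together, the generation path introduced by this commit follows roughly this pattern (condensed from the hunks above; sampling of the next token and the printing are elided):

    model.use_kv = 1;                 // switch encoder/layernorm to the reduced single-position launches
    for (int t = 1; t < genT; t++) {
        model.kv_offset = t - 1;      // the encoder only (re)computes the embedding at position t-1
        gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));
        // ... sample gen_tokens[t] from the logits at position t-1 ...
    }
    model.use_kv = 0;                 // training iterations keep the full B*T path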

