Commit 075e430
just pushing what i have. it's epsilon away from working, sigh. basically: at the point where the prints happen, gradients match. but once we backward through attention, rope and repkv, gradients don't match. attention hasn't changed so that can't be wrong (?), so it's either repkv or rope. i have to go slower and double check the backward pass of both of these in detail. also had to introduce one more additional buffer for the backward pass.
karpathy committed Sep 27, 2024
1 parent 28e4a7f commit 075e430
Showing 4 changed files with 132 additions and 29 deletions.
58 changes: 58 additions & 0 deletions llmc/repkv.cuh
@@ -48,6 +48,54 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv,
replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]);
}

__global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout,
int B, int N, int NH, int replicate_factor, int HD) {
// we have a single tensor dout of shape (B, N, 3 * NH * HD)
// we want to reduce-sum the gradients (for K and V) into dinp of shape (B, N, (NH + 2*(NH/replicate_factor)) * HD)
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * N * 3 * NH * HD) { return;}
int dout_idx = idx; // keep backup

// decode the dout index
int d = idx % HD;
idx /= HD;
int nh = idx % NH;
idx /= NH;
int c = idx % 3;
idx /= 3;
int n = idx % N;
int b = idx / N;

int dinp_idx;
int nh_total = NH + 2 * (NH / replicate_factor);

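// c == 0 is Q: the gradient passes straight through unchanged.
// c == 1 (K) and c == 2 (V): each KV head was replicated replicate_factor times in the forward,
// so its gradient is the sum over that group of query heads; only the first thread in each
// group (nh % replicate_factor == 0) does the reduction and writes the result.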
if (c == 0) {
dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d;
dinp[dinp_idx] = __ldcs(&dout[dout_idx]);
} else if (c == 1) {
if (nh % replicate_factor == 0) {
float reduced_sum = 0;
for (int i = 0; i < replicate_factor; i++) {
reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]);
}

dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d;
dinp[dinp_idx] = reduced_sum;
}

} else {
if (nh % replicate_factor == 0) {
float reduced_sum = 0;
for (int i = 0; i < replicate_factor; i++) {
reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]);
}
dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d;
dinp[dinp_idx] = reduced_sum;
}
}
}

// kernel launchers
void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD, cudaStream_t stream) {
// NH = number of query heads, NH_KV = number of key and value heads, HD = head dimension
const int block_size = 128;
@@ -61,3 +109,13 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_
}
cudaCheck(cudaGetLastError());
}

void repkv_backward(floatX* dinp, const floatX* dout,
const int B, const int T, const int NH, const int NH_KV, const int d) {
const int block_size = 128;
int total_threads = B * T * (3 * NH) * d;
int num_blocks = CEIL_DIV(total_threads, block_size);
int replicate_factor = NH / NH_KV;
repkv_backward_kernel1<<<num_blocks, block_size>>>(dinp, dout, B, T, NH, replicate_factor, d);
cudaCheck(cudaGetLastError());
}
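In line with the plan above to double check the backward pass in detail, a minimal CPU reference for repkv_backward can be compared element-by-element against the kernel on small random tensors. This is only a sketch under assumptions: the name repkv_backward_cpu and the float-only signature are illustrative, not repo code; it just mirrors the (B, T, 3*NH*HD) -> (B, T, (NH + 2*NH_KV)*HD) layout used by the kernel.

void repkv_backward_cpu(float* dinp, const float* dout,
                        int B, int T, int NH, int NH_KV, int HD) {
    // dout is (B, T, 3*NH*HD): Q grads, then replicated K grads, then replicated V grads
    // dinp is (B, T, (NH + 2*NH_KV)*HD): Q grads, then one K and one V grad per KV head
    int rep = NH / NH_KV;
    int nh_total = NH + 2 * NH_KV;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            const float* dout_bt = dout + ((size_t)b * T + t) * 3 * NH * HD;
            float* dinp_bt = dinp + ((size_t)b * T + t) * nh_total * HD;
            // Q: gradient passes straight through
            for (int h = 0; h < NH; h++) {
                for (int d = 0; d < HD; d++) {
                    dinp_bt[h * HD + d] = dout_bt[h * HD + d];
                }
            }
            // K and V: sum the gradients of the rep query heads that shared each KV head
            for (int kv = 0; kv < NH_KV; kv++) {
                for (int d = 0; d < HD; d++) {
                    float dk = 0.0f, dv = 0.0f;
                    for (int r = 0; r < rep; r++) {
                        dk += dout_bt[(NH + kv * rep + r) * HD + d];
                        dv += dout_bt[(2 * NH + kv * rep + r) * HD + d];
                    }
                    dinp_bt[(NH + kv) * HD + d] = dk;
                    dinp_bt[(NH + NH_KV + kv) * HD + d] = dv;
                }
            }
        }
    }
}

Running both versions on, say, B=1, T=4, NH=8, NH_KV=2, HD=16 with random inputs and diffing the outputs would isolate repkv_backward from rope as the source of the gradient mismatch.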
41 changes: 38 additions & 3 deletions llmc/rope.cuh
@@ -58,18 +58,44 @@ __global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const float
int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim);
int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim;
int idxi = idx_bth + 2 * d; // index in the input
// fetch the freqs_cis
int freqs_idx = t * head_dim + 2 * d;
float freqs_cos = freqs_cis[freqs_idx];
float freqs_sin = freqs_cis[freqs_idx + 1];
// fetch the input
float x_real = inp[idxi];
float x_imag = inp[idxi + 1];
// apply the rotation
out[idxi] = x_real * freqs_cos - x_imag * freqs_sin;
out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos;
}

__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int head_dim_half = head_dim / 2;
if (idx >= B * T * 3 * n_head * head_dim_half) return;
// decode the qkv index early so we can early exit if it's a value index
int qkv = (idx / (n_head * head_dim_half)) % 3;
if (qkv == 2) return; // no-op for v
// decode the individual indices and get the input index
int b = idx / (T * 3 * n_head * head_dim_half);
int t = (idx / (3 * n_head * head_dim_half)) % T;
int h = (idx / head_dim_half) % n_head;
int d = idx % head_dim_half;
int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim);
int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim;
int idxi = idx_bth + 2 * d; // index in the input
// fetch the freqs_cis
int freqs_idx = t * head_dim + 2 * d;
float freqs_cos = freqs_cis[freqs_idx];
float freqs_sin = freqs_cis[freqs_idx + 1];
// backward
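// the forward rotated each (real, imag) pair by theta; the map is linear, so the backward
// applies the transpose, i.e. a rotation by -theta, to the incoming gradient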
float dout_real = (float)dout[idxi];
float dout_imag = (float)dout[idxi + 1];
dinp[idxi] = dout_real * freqs_cos + dout_imag * freqs_sin;
dinp[idxi + 1] = -dout_real * freqs_sin + dout_imag * freqs_cos;
}

void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) {
// the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v
// we are going to launch exactly one thread per element of the output,
@@ -81,3 +107,12 @@ void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B
rope_forward_kernel1<<<num_blocks, block_size, 0, stream>>>(out, inp, freqs_cis, B, T, n_head, head_dim);
cudaCheck(cudaGetLastError());
}

void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) {
// backward pass of forward, mirrors the forward kernel in setup and indexing
const int block_size = 128;
int total_threads = B * T * 3 * n_head * head_dim / 2;
int num_blocks = CEIL_DIV(total_threads, block_size);
rope_backward_inplace_kernel1<<<num_blocks, block_size, 0, stream>>>(dinp, dout, freqs_cis, B, T, n_head, head_dim);
cudaCheck(cudaGetLastError());
}
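Same idea for rope: the forward rotates each (real, imag) pair by theta, and since that map is linear the backward applies the transpose, i.e. a rotation by -theta, to the incoming gradient. Below is a hedged CPU sketch for cross-checking rope_backward_inplace; the name rope_backward_cpu and the float-only signature are assumptions, not repo code, and unlike the in-place kernel it writes the untouched V gradients through explicitly, so dinp can be a separate buffer.

void rope_backward_cpu(float* dinp, const float* dout, const float* freqs_cis,
                       int B, int T, int n_head, int head_dim) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int qkv = 0; qkv < 3; qkv++) {
                for (int h = 0; h < n_head; h++) {
                    for (int d = 0; d < head_dim / 2; d++) {
                        int idx = (((b * T + t) * 3 + qkv) * n_head + h) * head_dim + 2 * d;
                        float dy_real = dout[idx];
                        float dy_imag = dout[idx + 1];
                        if (qkv == 2) { // V is not rotated in the forward; gradient passes through
                            dinp[idx] = dy_real;
                            dinp[idx + 1] = dy_imag;
                            continue;
                        }
                        float fc = freqs_cis[t * head_dim + 2 * d];     // cos(theta)
                        float fs = freqs_cis[t * head_dim + 2 * d + 1]; // sin(theta)
                        dinp[idx]     =  dy_real * fc + dy_imag * fs;   // first row of R(theta)^T
                        dinp[idx + 1] = -dy_real * fs + dy_imag * fc;   // second row of R(theta)^T
                    }
                }
            }
        }
    }
}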
48 changes: 28 additions & 20 deletions train_llama3.cu
@@ -64,9 +64,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/adamw.cuh"
// defines: global_norm_squared
#include "llmc/global_norm.cuh"
// defines: repkv_forward, repkv_backward
#include "llmc/repkv.cuh"
// defines: precompute_freqs_cis, rope_forward, rope_backward_inplace
#include "llmc/rope.cuh"
// defines: swiglu_forward, swiglu_backward
#include "llmc/swiglu.cuh"
@@ -197,7 +197,7 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen
return params_memory;
}

constexpr int NUM_ACTIVATION_TENSORS = 22;
typedef struct {
floatX* encoded; // (B, T, C)
floatX* ln1; // (L, B, T, C)
@@ -234,6 +234,7 @@ typedef struct {
// some additional scratch buffers
floatX* scratch_bt4c; // (B, T, 4*C)
floatX* scratch_btc; // (B, T, C)
floatX* scratch_bt4c2; // (B, T, 4*C); for simplicity, use this one for the backward pass too, probably not needed
} ActivationTensors;


@@ -292,6 +293,7 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp))));
tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels);
tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C);
tensors[21] = TENSOR_SPEC(data->scratch_bt4c2, B * T * ffn_channels);
}

void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) {
@@ -839,7 +841,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int

// get the pointers of the weights for this layer
floatX* l_ln1w = params.ln1w + l * C;
floatX* l_qkvw = params.qkvw + l * qkv_channels * C;
floatX* l_attprojw = params.attprojw + l * C * C;
floatX* l_ln2w = params.ln2w + l * C;
@@ -860,13 +861,11 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
floatX* dl_fcprojb = grads.fcprojb + l * C;
// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
floatX* l_fch_pre_gelu = acts.fch + l * B * T * ffn_channels;
floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu;
@@ -875,6 +874,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// re-using this memory in every Transformer block as we calculate backward pass

floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c;
floatX* dl_bt4c2 = (floatX*)model->acts.scratch_bt4c2; // same size as dl_bt4c, just a second buffer

// start the backward pass for this layer
if(model->recompute >= 1) {
@@ -886,13 +886,16 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// backward the 2nd matmul of MLP
matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream);
// backward the swiglu here, use the second scratch buffer dl_bt4c2 to hold the grad because SwiGLU can't be inplace
swiglu_backward(dl_bt4c2, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream);
// backward the 1st matmul of MLP
if(model->recompute >= 2) {
// same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
rmsnorm_forward(l_ln2, l_ln2_rstd, l_residual2, l_ln2w, B, T, C, main_stream);
}
matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c2, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream);
// rmsnorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
rmsnorm_backward(dresidual, dl_ln2w, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_rstd, B, T, C, main_stream);
matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream);

// ------------------------------------------------------------------------
// DEBUGGING: we only work until this point right now, so exit here
@@ -905,19 +908,18 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
}
// write to .bin file
// move output to cpu
// int sz = B*T*qkv_channels; //B*T*C;
int sz = B*T*C;
floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX));
cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost));
FILE* f = fopen("out.bin", "wb");
fwrite(cpu_output, sizeof(floatX), sz, f);
fclose(f);
exit(0);
// ------------------------------------------------------------------------

#ifdef ENABLE_CUDNN
printf("cuDNN path TODO\n"); exit(0);
float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream);
#else
@@ -927,13 +929,19 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
#endif
// backward rope (this can be done in-place)
rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream);
// backward repkv (use the second scratch buffer dl_bt4c2 as the output gradient buffer here)
repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd);

// <--- here the gradients don't match, so there is an issue in between
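// (one way to bisect this, sketched rather than prescribed: check rope_backward_inplace and
//  repkv_backward in isolation against simple CPU reference loops on small random tensors --
//  see the hedged sketches after the repkv.cuh and rope.cuh diffs above -- so a mismatch can
//  be pinned on a single op independent of the surrounding backward graph)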

// backward QKV projection
if(model->recompute >= 2) {
rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream);
}
matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c2, l_ln1, l_qkvw, scratchF, B, T, C, qkv_channels, main_stream);
rmsnorm_backward(dresidual, dl_ln1w, scratchF, dl_btc, residual, l_ln1w, l_ln1_rstd, B, T, C, main_stream);

// Accumulate gradients from this layer in a background stream.
if(last_step) {
14 changes: 8 additions & 6 deletions train_llama3.py
@@ -197,6 +197,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
att = F.softmax(scores.float(), dim=-1).type_as(q)
y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD)
y = y.transpose(1, 2).contiguous().view(B, T, C)

DEBUG_POINT = y.detach()
DEBUG_POINT = DEBUG_POINT.requires_grad_(True)
self.DEBUG_POINT = DEBUG_POINT
y = DEBUG_POINT
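# (the detach + requires_grad_ pair cuts the autograd graph here, so after backward()
# DEBUG_POINT.grad holds exactly dloss/dy at the input of c_proj, ready to compare against
# the corresponding gradient buffer on the CUDA side; it also stops gradients from flowing
# further back into the attention, which is fine for this debug-only path)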

y = self.c_proj(y)
return y

@@ -234,11 +240,7 @@ def __init__(self, config):

def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
x = x + self.mlp(self.ln_2(x))
return x

# -----------------------------------------------------------------------------
@@ -1260,7 +1262,7 @@ def get_lr(it):

# ---------------------------------------------------------------------
# DEBUGGING: print first 32 elements of x
x = model.transformer.h[-1].attn.DEBUG_POINT.grad
for i in range(32):
print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item()))
# write to .bin file