Skip to content

Commit

Permalink
Merge pull request #705 from gordicaleksa/refactor_c
Browse files Browse the repository at this point in the history
Refactor C code
  • Loading branch information
karpathy authored Jul 30, 2024
2 parents ef12d1b + 16c990f commit 29aacba
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions train_gpt2.c
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,36 @@ typedef struct {
float* losses; // (B, T)
} ActivationTensors;

void fill_in_activation_sizes(size_t* act_sizes, GPT2Config config, int B, int T) {
size_t C = config.channels;
size_t NH = config.num_heads;
size_t L = config.num_layers;
size_t Vp = config.padded_vocab_size;
act_sizes[0] = B * T * C; // encoded
act_sizes[1] = L * B * T * C; // ln1
act_sizes[2] = L * B * T; // ln1_mean
act_sizes[3] = L * B * T; // ln1_rstd
act_sizes[4] = L * B * T * 3 * C; // qkv
act_sizes[5] = L * B * T * C; // atty
act_sizes[6] = L * B * NH * T * T; // preatt
act_sizes[7] = L * B * NH * T * T; // att
act_sizes[8] = L * B * T * C; // attproj
act_sizes[9] = L * B * T * C; // residual2
act_sizes[10] = L * B * T * C; // ln2
act_sizes[11] = L * B * T; // ln2_mean
act_sizes[12] = L * B * T; // ln2_rstd
act_sizes[13] = L * B * T * 4 * C; // fch
act_sizes[14] = L * B * T * 4 * C; // fch_gelu
act_sizes[15] = L * B * T * C; // fcproj
act_sizes[16] = L * B * T * C; // residual3
act_sizes[17] = B * T * C; // lnf
act_sizes[18] = B * T; // lnf_mean
act_sizes[19] = B * T; // lnf_rstd
act_sizes[20] = B * T * Vp; // logits
act_sizes[21] = B * T * Vp; // probs
act_sizes[22] = B * T; // losses
}

float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) {
size_t num_activations = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
Expand Down Expand Up @@ -678,7 +708,6 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {

// read in model from a checkpoint file
FILE *model_file = fopenCheck(checkpoint_path, "rb");
if (model_file == NULL) { printf("Error opening model file\n"); exit(1); }
int model_header[256];
freadCheck(model_header, sizeof(int), 256, model_file);
if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(1); }
Expand Down Expand Up @@ -763,29 +792,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
model->batch_size = B;
model->seq_len = T;
// and now allocate the space
model->act_sizes[0] = B * T * C; // encoded
model->act_sizes[1] = L * B * T * C; // ln1
model->act_sizes[2] = L * B * T; // ln1_mean
model->act_sizes[3] = L * B * T; // ln1_rstd
model->act_sizes[4] = L * B * T * 3*C; // qkv
model->act_sizes[5] = L * B * T * C; // atty
model->act_sizes[6] = L * B * NH * T * T; // preatt
model->act_sizes[7] = L * B * NH * T * T; // att
model->act_sizes[8] = L * B * T * C; // attproj
model->act_sizes[9] = L * B * T * C; // residual2
model->act_sizes[10] = L * B * T * C; // ln2
model->act_sizes[11] = L * B * T; // ln2_mean
model->act_sizes[12] = L * B * T; // ln2_rstd
model->act_sizes[13] = L * B * T * 4*C; // fch
model->act_sizes[14] = L * B * T * 4*C; // fch_gelu
model->act_sizes[15] = L * B * T * C; // fcproj
model->act_sizes[16] = L * B * T * C; // residual3
model->act_sizes[17] = B * T * C; // lnf
model->act_sizes[18] = B * T; // lnf_mean
model->act_sizes[19] = B * T; // lnf_rstd
model->act_sizes[20] = B * T * Vp; // logits
model->act_sizes[21] = B * T * Vp; // probs
model->act_sizes[22] = B * T; // losses
fill_in_activation_sizes(model->act_sizes, model->config, B, T);
size_t num_activations = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
num_activations += model->act_sizes[i];
Expand Down

0 comments on commit 29aacba

Please sign in to comment.