refactor gpt2 gpt3 descriptor to keep legacy behavior of -e, be strict everywhere, remove interpolation for now
karpathy committed Jul 15, 2024
1 parent 4e629e2 commit 64db3de
Showing 2 changed files with 78 additions and 78 deletions.
11 changes: 11 additions & 0 deletions llmc/utils.h
@@ -212,4 +212,15 @@ extern inline int find_max_step(const char* output_log_dir) {
     return max_step;
 }
 
+extern inline bool ends_with_bin(const char* str) {
+    // checks if str ends with ".bin". could be generalized in the future.
+    if (str == NULL) { return false; }
+    size_t len = strlen(str);
+    const char* suffix = ".bin";
+    size_t suffix_len = strlen(suffix);
+    if (len < suffix_len) { return false; }
+    bool suffix_matches = strncmp(str + len - suffix_len, suffix, suffix_len) == 0;
+    return suffix_matches;
+}
+
 #endif
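The new helper's contract is easiest to see by example. A minimal usage sketch (hypothetical test driver, not part of the commit; include path assumed):

```c
#include <assert.h>
#include "llmc/utils.h" // provides ends_with_bin

int main(void) {
    // .bin paths are treated as model checkpoints by the -e flag in train_gpt2.cu
    assert(ends_with_bin("gpt2_124M_bf16.bin"));
    assert(ends_with_bin(".bin"));        // the bare suffix still matches (len == suffix_len)
    // everything else falls through to the descriptor parser
    assert(!ends_with_bin("d12"));        // legacy GPT-2 depth descriptor
    assert(!ends_with_bin("gpt3:c768")); // GPT-3 channels descriptor
    assert(!ends_with_bin(NULL));         // NULL is rejected explicitly
    return 0;
}
```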
145 changes: 67 additions & 78 deletions train_gpt2.cu
@@ -11,7 +11,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 #include <sys/types.h>
 // ----------- CPU utilities -----------
 // defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
-// defines: create_dir_if_not_exists, find_max_step
+// defines: create_dir_if_not_exists, find_max_step, ends_with_bin
 #include "llmc/utils.h"
 // defines: tokenizer_init, tokenizer_decode, tokenizer_free
 #include "llmc/tokenizer.h"
@@ -375,9 +375,6 @@ typedef struct {
 
 void gpt2_init_common(GPT2 *model) {
     // common inits outside of the model weights
-    // the weights are initialized either in:
-    // - gpt2_build_from_checkpoint() if loading from a checkpoint
-    // - gpt2_build_from_random() if starting from scratch
     // memory lazily initialized in forward()
     model->acts_memory = NULL;
     model->inputs = NULL;
@@ -494,90 +491,71 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
     cudaCheck(cudaDeviceSynchronize());
 }
 
-void gpt2_setup(GPT2Config* config, const char* depth_str) {
+void gpt2_set_hyperparameters(GPT2Config* config, const char* depth_str) {
     int depth = atoi(depth_str);
     assert(depth > 0); // atoi returns 0 if not a number
     int channels, num_heads;
-    switch(depth) {
-        case 12:
-        case 24:
-        case 36:
-        case 48:
-            break;
-        default:
-            printf("Specified depth %d does not correspond to an official GPT2 model size.\n", depth);
-            printf(" Generating an interpolated model.\n");
-    }
-
-    if      (depth <= 6)  { channels = 384;  num_heads = 6;  } // (unofficial) gpt2-tiny (30M)
-    else if (depth <= 12) { channels = 768;  num_heads = 12; } // gpt2 (124M)
-    else if (depth <= 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M)
-    else if (depth <= 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M)
-    else if (depth <= 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M)
-    else if (depth <= 60) { channels = 1920; num_heads = 30; } // (unofficial) 2.7B
-    else if (depth <= 72) { channels = 2880; num_heads = 30; } // (unofficial) 7.3B
-    else if (depth <= 84) { channels = 3456; num_heads = 36; } // (unofficial) 12.2B
-    else { fprintf(stderr, "Unsupported depth %d for now\n", depth); exit(EXIT_FAILURE); }
+    if      (depth == 6)  { channels = 384;  num_heads = 6;  } // (unofficial) gpt2-tiny (30M)
+    else if (depth == 12) { channels = 768;  num_heads = 12; } // gpt2 (124M)
+    else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M)
+    else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M)
+    else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M)
+    else if (depth == 60) { channels = 1920; num_heads = 30; } // (unofficial) 2.7B
+    else if (depth == 72) { channels = 2880; num_heads = 30; } // (unofficial) 7.3B
+    else if (depth == 84) { channels = 3456; num_heads = 36; } // (unofficial) 12.2B
+    else { fprintf(stderr, "Unsupported GPT-2 depth: %d\n", depth); exit(EXIT_FAILURE); }
     config->num_layers = depth;
     config->channels = channels;
     config->num_heads = num_heads;
     config->max_seq_len = 1024;
 }
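To see the new strict behavior concretely, here is a hypothetical check of the depth-to-width mapping (the assertions just restate the table above; not part of the commit):

```c
GPT2Config config;
gpt2_set_hyperparameters(&config, "12"); // the caller has already stripped the 'd' prefix
assert(config.num_layers == 12);    // depth is taken verbatim
assert(config.channels == 768);     // gpt2 (124M) width
assert(config.num_heads == 12);
assert(config.max_seq_len == 1024); // GPT-2 context length
// a depth like 13 no longer yields an "interpolated" model; it now exits with an error
```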

-void gpt3_setup(GPT2Config* config, const char* channels_str) {
-    // note: we do not quite get the same model as gpt3, because we consistently use
-    // dense attention, instead of alternating banded attention
+void gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) {
+    // we use channels instead of depth for GPT-3 because GPT-3 model depths are not one-to-one
+    // note that our models are not necessarily identical to GPT-3 because
+    // we use dense attention, not the alternating dense/banded attention of GPT-3
     int channels = atoi(channels_str);
-
-    switch(channels) {
-        case 768:
-        case 1024:
-        case 1536:
-        case 2048:
-        case 2560:
-        case 4096:
-        case 5140:
-        case 12288:
-            break;
-        default:
-            printf("Specified channels %d do not correspond to an official GPT3 model size.\n", channels);
-            printf(" Generating an interpolated model.\n");
-    }
-
     assert(channels > 0); // atoi returns 0 if not a number
     int depth, head_size;
-    if      (channels <= 384)   { depth = 6;  head_size = 64;  } // (unofficial) gpt3-tiny (31M)
-    else if (channels <= 768)   { depth = 12; head_size = 64;  } // gpt3-small (125M)
-    else if (channels <= 1024)  { depth = 24; head_size = 64;  } // gpt3-medium (350M)
-    else if (channels <= 1536)  { depth = 24; head_size = 96;  } // gpt3-large (760M)
-    else if (channels <= 2048)  { depth = 24; head_size = 128; } // gpt3-xl (1.3B) [heads fixed]
-    else if (channels <= 2560)  { depth = 32; head_size = 80;  } // gpt3-2.7B
-    else if (channels <= 4096)  { depth = 32; head_size = 128; } // gpt3-6.7B
-    else if (channels <= 5140)  { depth = 40; head_size = 128; } // gpt3-13B
-    else if (channels <= 12288) { depth = 96; head_size = 128; } // gpt3 (175B)
-    else { fprintf(stderr, "Unsupported channels %d for now\n", channels); exit(EXIT_FAILURE); }
-    if(channels % head_size != 0) {
-        fprintf(stderr, "Number of channels %d incompatible with head size %d\n", channels, head_size);
-        fprintf(stderr, " The next valid number of channels is %d\n", CEIL_DIV(channels, head_size) * head_size);
-        exit(EXIT_FAILURE);
-    }
+    if      (channels == 384)   { depth = 6;  head_size = 64;  } // (unofficial) gpt3-tiny (31M)
+    else if (channels == 768)   { depth = 12; head_size = 64;  } // gpt3-small (125M)
+    else if (channels == 1024)  { depth = 24; head_size = 64;  } // gpt3-medium (350M)
+    else if (channels == 1536)  { depth = 24; head_size = 96;  } // gpt3-large (760M)
+    else if (channels == 2048)  { depth = 24; head_size = 128; } // gpt3-xl (1.3B) [heads fixed]
+    else if (channels == 2560)  { depth = 32; head_size = 80;  } // gpt3-2.7B
+    else if (channels == 4096)  { depth = 32; head_size = 128; } // gpt3-6.7B
+    else if (channels == 5140)  { depth = 40; head_size = 128; } // gpt3-13B
+    else if (channels == 12288) { depth = 96; head_size = 128; } // gpt3 (175B)
+    else { fprintf(stderr, "Unsupported GPT-3 channels: %d\n", channels); exit(EXIT_FAILURE); }
+    assert(channels % head_size == 0);
     config->num_layers = depth;
     config->channels = channels;
     config->num_heads = channels / head_size;
-    config->max_seq_len = 2048;
+    config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2
 }
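And the GPT-3 counterpart, again a hypothetical sketch; here the head count falls out of channels divided by the fixed per-model head size:

```c
GPT2Config config;
gpt3_set_hyperparameters(&config, "768"); // the caller has already stripped the 'gpt3:c' prefix
assert(config.num_layers == 12);      // gpt3-small (125M) is 12 layers deep
assert(config.num_heads == 768 / 64); // channels / head_size = 12 heads
assert(config.max_seq_len == 2048);   // GPT-3 context length
```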

-void gpt2_build_from_random(GPT2 *model, const char* config) {
-    // init random (training from scratch)
-    const char* cfg = strchr(config, ':') + 1;
-    if(strncmp(config, "gpt2", 4) == 0) {
-        gpt2_setup(&model->config, cfg);
-    } else if(strncmp(config, "gpt3", 4) == 0) {
-        gpt3_setup(&model->config, cfg);
+void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
+    // The model descriptor can be:
+    // - legacy format "dX", where X is a number, e.g. "d12". This creates a GPT-2 model with 12 layers.
+    // - new explicit format "gpt2:dX", same as above, e.g. "gpt2:d48" for GPT-2 with 48 layers.
+    // - "gpt3:cX", where X is now the channel count, e.g. "gpt3:c768" is the smallest GPT-3 model.
+
+    // check for a valid prefix and dispatch to the right setup function
+    assert(descriptor != NULL);
+    size_t len = strlen(descriptor);
+    if (len > 1 && descriptor[0] == 'd') {
+        gpt2_set_hyperparameters(&model->config, descriptor + 1); // pass along the depth str without the 'd'
+    } else if (len > 6 && strncmp(descriptor, "gpt2:d", 6) == 0) {
+        gpt2_set_hyperparameters(&model->config, descriptor + 6); // pass along the depth str without the 'gpt2:d'
+    } else if (len > 6 && strncmp(descriptor, "gpt3:c", 6) == 0) {
+        gpt3_set_hyperparameters(&model->config, descriptor + 6); // pass along the channels str without the 'gpt3:c'
     } else {
-        fprintf(stderr, "Unsupported model type %s\n", config); exit(EXIT_FAILURE);
+        fprintf(stderr, "Unsupported model descriptor: %s\n", descriptor); exit(EXIT_FAILURE);
     }
 
     // both GPT-2 and GPT-3 use the same tokenizer with 50257 tokens
     model->config.vocab_size = 50257;
-    model->config.padded_vocab_size = 50304; // padded to 128
+    model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel efficiency
 
     // fill in all the parameter tensor dimensions and types
     fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config);
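Taken together, the dispatcher accepts exactly three descriptor spellings. A hypothetical driver (not in the commit) showing accepted and rejected inputs:

```c
GPT2 model;
gpt2_init_common(&model);
// any of the three forms works here:
//   "d12"       legacy GPT-2 form, preserves the old -e behavior
//   "gpt2:d48"  explicit GPT-2 form (gpt2-xl)
//   "gpt3:c768" GPT-3 form, keyed on channel count
gpt_build_from_descriptor(&model, "gpt3:c768");
assert(model.config.channels == 768);
assert(model.config.max_seq_len == 2048);
// "d13" or "gpt3:c999" now exit(EXIT_FAILURE): interpolation was removed in this commit
```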
@@ -1415,7 +1393,7 @@ void error_usage() {
     // file system input / output
     fprintf(stderr, "  -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n");
     fprintf(stderr, "  -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n");
-    fprintf(stderr, "  -e <string> input from model at this filename (default = gpt2_124M_bf16.bin)\n");
+    fprintf(stderr, "  -e <string> input .bin filename or descriptor; see code comments for docs (default = gpt2_124M_bf16.bin)\n");
     fprintf(stderr, "  -o <string> output log dir (default = NULL, no logging)\n");
     fprintf(stderr, "  -n <int>    write optimization checkpoints every how many steps? (default 0, don't)\n");
     fprintf(stderr, "  -nk <int>   max number of checkpoints to keep in the directory, removing old ones (0 = disable, default)\n");
@@ -1616,22 +1594,33 @@ int main(int argc, char *argv[]) {
     // build the GPT-2 model
     GPT2 model;
     gpt2_init_common(&model);
-    // if load_filename is of the form "dX" where X is an integer (e.g. d12), then we build
-    // a random model with the depth of the model specified by X (e.g. 12). otherwise interpret
-    // this variable as a checkpoint filename, and load that checkpoint
-    assert(strlen(load_filename) >= 2);
     if (resuming == 1) {
         // if `-y 1` was set, then we are resuming from the latest checkpoint
         gpt2_build_from_checkpoint(&model, filename_buffer);
-    } else if (strchr(load_filename, ':') != nullptr) {
-        gpt2_build_from_random(&model, load_filename);
-    } else {
+    } else if (ends_with_bin(load_filename)) {
+        // otherwise, if this is a .bin file, we assume it's a model, let's init from it
         gpt2_build_from_checkpoint(&model, load_filename);
+    } else {
+        // if it's not .bin, it could be a "special descriptor". This descriptor is used to
+        // construct GPT-2 / GPT-3 models in a convenient format. See the function for docs.
+        gpt_build_from_descriptor(&model, load_filename);
     }
-    // in any case, this must be true or we'd index beyond the model's wpe (position embedding table)
-    assert(T <= model.config.max_seq_len);
+    // cross-check the desired sequence length T with the model's max sequence length
+    if (T != model.config.max_seq_len) {
+        printf0("Warning: sequence length T=%d (set with -t) is not equal to model's max_seq_len=%d\n",
+                T, model.config.max_seq_len);
+        printf0("HINT: If you're training a GPT-2 use -t 1024. If GPT-3, use -t 2048.\n");
+        printf0("This could in principle be ok if T <= max_seq_len, but this is a major footgun...\n");
+        printf0("Failing catastrophically for now.\n");
+        exit(EXIT_FAILURE);
+    }
 
     model.use_master_weights = use_master_weights;
     model.gelu_fusion = gelu_fusion;
     model.recompute = recompute;
-    printf0("| weight init method        | %-50s |\n", resuming == 1 ? "intermediate checkpoint" : (load_filename[0] == 'd' ? "random" : "OpenAI's GPT-2 checkpoint"));
+    printf0("| weight init method        | %-50s |\n", resuming == 1 ? "intermediate checkpoint" : load_filename);
     printf0("| max_sequence_length T     | %-50d |\n", model.config.max_seq_len);
     printf0("| vocab_size V              | %-50d |\n", model.config.vocab_size);
     printf0("| padded_vocab_size Vp      | %-50d |\n", model.config.padded_vocab_size);
