Restore from master weights (& allow restoring from a checkpoint of different precision) #702
Merged: karpathy merged 6 commits into karpathy:master from ademeure:restore_from_master_weights on Jul 30, 2024.
Changes from all commits (6 commits):
f470fbd (ademeure): Allow restoring weights from the master weights of a checkpoint (dete…
5cae10f (ademeure): make restoring from master weights actually work
9781627 (ademeure): allow restoring from checkpoint of different precision
2eabc22 (ademeure): simplify a little bit
4d77ece (ademeure): simplified further (don't need non-functional error checking...)
52e6e0f (ademeure): fix bug from merge (init_state set to false too late)
@@ -311,7 +311,8 @@ typedef struct {
     float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps
     float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
     unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
+    unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights
     int use_master_weights; // keep master weights copy in float for optim update? 0|1
     bool init_state; // set to true if master weights need to be initialized
     int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward)
     int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2
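For context on the new rng_state_last_update field: the low-precision weights are produced by stochastically rounding the FP32 master weights, seeded from rng_state, so the exact BF16 bits are only reproducible if the same RNG state is replayed. A minimal standalone sketch of that idea (illustrative only; xorshift32 and stochastic_round_to_bf16 are made-up helpers, not the kernels this PR uses):

```c
#include <stdint.h>
#include <string.h>

// tiny illustrative RNG (llm.c uses its own deterministic generator)
static uint32_t xorshift32(uint32_t *state) {
    *state ^= *state << 13;
    *state ^= *state >> 17;
    *state ^= *state << 5;
    return *state;
}

// Stochastically round FP32 to BF16 (the top 16 bits of the float):
// add random noise below the truncation point, then truncate. Ignoring
// NaN/overflow corner cases, the result is fully determined by the value
// and the RNG state, so replaying the RNG state saved before the last
// update re-creates the exact same BF16 weights from the FP32 masters.
static uint16_t stochastic_round_to_bf16(float x, uint32_t *rng) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    bits += xorshift32(rng) & 0xFFFFu; // the 16 bits truncation discards
    return (uint16_t)(bits >> 16);
}
```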
@@ -438,7 +439,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
     fcloseCheck(model_file);
 }
 
-void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
+void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool resuming=false) {
 
     if (PRECISION_MODE == PRECISION_FP16) {
         // TODO for later perhaps, would require us dynamically converting the
@@ -461,16 +462,20 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
         fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
         exit(EXIT_FAILURE);
     }
-    if (PRECISION_MODE == PRECISION_BF16 && version != 5) {
-        fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path);
-        fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n");
-        exit(EXIT_FAILURE);
-    }
-    if (PRECISION_MODE == PRECISION_FP32 && version != 3) {
-        fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path);
-        fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n");
-        fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n");
-        exit(EXIT_FAILURE);
-    }
+    // check if the precision mode matches the model (don't care if restoring from master weights!)
+    if (!resuming || !model->use_master_weights) {
+        if (PRECISION_MODE == PRECISION_BF16 && version != 5) {
+            fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path);
+            fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n");
+            exit(EXIT_FAILURE);
+        }
+        if (PRECISION_MODE == PRECISION_FP32 && version != 3) {
+            fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path);
+            fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n");
+            fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n");
+            exit(EXIT_FAILURE);
+        }
+    }
 
     // read in hyperparameters
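Design note: this relaxation is safe because the master weights are kept in full FP32 regardless of the compiled PRECISION_MODE. When resuming with use_master_weights enabled, the low-precision weights in the model file are never read (see the next hunk), so a checkpoint written by a BF16 build can be resumed by an FP32 build and vice versa, which is the second half of the PR title.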
@@ -483,9 +488,11 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
 
     gpt2_allocate_weights(model);
 
-    // read in all the parameters from file and copy them to device
-    file_to_device(model->params_memory, model_file, model->num_parameters_bytes,
-                   IO_BUF_SIZE, main_stream);
+    // read in all the parameters from file and copy them to device (if we need them)
+    // if we are restoring with master weights, ignore these weights and init from master instead
+    if (!resuming || !model->use_master_weights) {
+        file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream);
+    }
     fcloseCheck(model_file);
 
     // only return from this function once we are certain the params are ready on the GPU
@@ -1008,7 +1015,8 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
     return grad_norm_cpu;
 }
 
-void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config) {
+void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t,
+                 MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {
     // update the model parameters using the AdamW optimizer
     // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs
     // so we may not be responsible for the entire parameter tensor
@@ -1028,6 +1036,10 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
         cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));
         cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));
     }
+
+    // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint
+    model->rng_state_last_update = model->rng_state;
+
     // AdamW update
     // handle adamw for all the transformer blocks
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
@@ -1064,13 +1076,17 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
             cudaCheck(cudaGetLastError());
         }
 
-        // ok finally call the kernel
-        adamw_update(param_ptr, master_ptr, grad_ptr,
-                     m_ptr, v_ptr,
-                     shard.size, tensor.size, tensor.size, shard.size, num_layers,
-                     learning_rate,
-                     beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);
-        cudaCheck(cudaGetLastError());
+        if (init_from_master_only) {
+            // when resuming training from a checkpoint with master weights (allows changing precision)
+            init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
+        } else {
+            // ok finally call the kernel to update the weights with AdamW
+            adamw_update(param_ptr, master_ptr, grad_ptr,
+                         m_ptr, v_ptr,
+                         shard.size, tensor.size, tensor.size, shard.size, num_layers,
+                         learning_rate,
+                         beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);
+        }
 
         if (multi_gpu_config->zero_stage == 1) {
 #if MULTI_GPU
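The real init_from_master kernel is added elsewhere in this PR; conceptually it is the rounding half of adamw_update with the optimizer math stripped out. A hedged sketch of the idea (simplified signature, no shard/layer strides, and a stand-in hash instead of llm.c's actual RNG):

```cuda
#include <cuda_bf16.h>

// Each thread re-rounds one FP32 master weight into the low-precision
// parameter tensor. With a seed derived from the saved RNG state, the
// output matches what the last adamw_update wrote, bit for bit.
__global__ void init_from_master_sketch(__nv_bfloat16* params, const float* master,
                                        size_t n, unsigned int seed) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) { return; }
    unsigned int r = (seed ^ (unsigned int)i) * 2654435761u; // illustrative hash
    // stochastic rounding: perturb the 16 bits that BF16 truncation discards
    unsigned int bits = __float_as_uint(master[i]) + (r & 0xFFFFu);
    __nv_bfloat16_raw raw;
    raw.x = (unsigned short)(bits >> 16);
    params[i] = __nv_bfloat16(raw);
}
```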
@@ -1189,6 +1205,7 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader)
     state_header[10] = step; // step of the optimization
     // model rng state, start at 20 to leave some padding
     *((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state
+    *((unsigned long long*)&state_header[22]) = model->rng_state_last_update; // last gpt2_update
     // dataloader state, start at 30 to leave some padding
     *((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset
     *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard
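For orientation, the slots touched in these two hunks give the following resume-state header layout (a sketch; assuming state_header is an array of 32-bit ints as the pointer casts imply, so each 64-bit value spans two slots):

```c
// state_header layout (32-bit slots; 64-bit values occupy two slots)
// [5]      int                  should_shuffle flag of the dataloader
// [10]     int                  step of the optimization
// [20..21] unsigned long long   model->rng_state (final RNG state)
// [22..23] unsigned long long   model->rng_state_last_update (new in this PR)
// [30..31] size_t               loader->current_shard_idx
// [32..33] size_t               loader->current_sample_idx
```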
@@ -1225,6 +1242,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
     int should_shuffle = state_header[5]; // shuffle state of the dataloader
     *step = state_header[10]; // step of the optimization
     model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state
+    model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update
     size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index
     size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard
@@ -1237,17 +1255,21 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
         printf0("Error: Master weights requested, but not present in state file.");
         exit(EXIT_FAILURE);
     }
 
-    model->init_state = false; // we just got the state from file, no need to do first-touch init
     assert(model->m_memory != nullptr);
     assert(model->v_memory != nullptr);
     file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
     file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
     if(model->use_master_weights) {
         assert(model->master_weights != nullptr);
         file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
+        // restore weights from the master weights using the RNG state before last weight update
+        model->rng_state = model->rng_state_last_update;
+        gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true);
+        model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this
     }
 
+    model->init_state = false; // we just got the state from file, no need to do first-touch init
+
     // revive the DataLoader object and its state
     loader->should_shuffle = should_shuffle;
     if (should_shuffle == 1) {
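Note the RNG round-trip this enables: load_state rewinds rng_state to rng_state_last_update, calls gpt2_update in init_from_master_only mode so every parameter is re-rounded from the FP32 masters with exactly the seeds the last real update used, then restores the final rng_state from the checkpoint so subsequent steps draw the same random numbers as an uninterrupted run. The learning rate, betas, and other hyperparameters are passed as zeros because the init-only path never reads them.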
@@ -1532,7 +1554,7 @@ int main(int argc, char *argv[]) {
     gpt2_init_common(&model);
     if (resuming == 1) {
         // if `-y 1` was set, then we are resuming from the latest checkpoint
-        gpt2_build_from_checkpoint(&model, filename_buffer);
+        gpt2_build_from_checkpoint(&model, filename_buffer, true);
     } else if (ends_with_bin(load_filename)) {
         // otherwise, if this is a .bin file, we assume it's a model, let's init from it
         gpt2_build_from_checkpoint(&model, load_filename);

Inline review comment on the gpt2_build_from_checkpoint(&model, filename_buffer, true) line: "at this point"
Review comments:
"unused check_identical?"
Reply: "oops, that used to be for debug logs that I removed, thanks"