Restore from master weights (& allow restoring from a checkpoint of different precision) #702

Merged (6 commits) on Jul 30, 2024
22 changes: 21 additions & 1 deletion llmc/adamw.cuh
@@ -61,6 +61,16 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
);
}

template <typename Tp>
__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
Owner: unused check_identical?

Contributor Author: oops, that used to be for debug logs that I removed, thanks

size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_parameters) { return; }
params_memory += blockIdx.y * w_stride; // adjust for layer offset
master_params_memory += blockIdx.y * s_stride;
stochastic_rounding(master_params_memory[idx], &params_memory[idx], seed);
}

template <typename Tp, typename Tg>
void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
@@ -75,4 +85,14 @@ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memo
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,
grad_scale, seed);
cudaCheck(cudaGetLastError());
}
}

template <typename Tp>
void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream, bool check_identical=false) {
int block_size = 512; // must match block size of adamw_update so that RNG also matches
int num_blocks = CEIL_DIV(num_parameters, block_size);
init_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
(params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed, check_identical);
cudaCheck(cudaGetLastError());
}
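
Note (an aside, not part of the diff): init_from_master re-materializes the working-precision parameters from the fp32 master copy; the bit-identical part comes from replaying the optimizer's RNG, which the train_gpt2.cu changes below take care of. The standalone sketch that follows illustrates only the re-materialization step and is an assumption, not llm.c code: it uses plain round-to-nearest (__float2bfloat16) in place of the repo's stochastic_rounding helper, ignores the per-layer stride/slice handling, and every name in it (params_from_master, master, params, n) is made up.

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

// toy kernel: rebuild the low-precision params from the fp32 master copy
__global__ void params_from_master(__nv_bfloat16* params, const float* master, size_t n) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) { params[idx] = __float2bfloat16(master[idx]); } // round-to-nearest here; llm.c rounds stochastically
}

int main() {
    const size_t n = 1024;
    float* master = nullptr; __nv_bfloat16* params = nullptr;
    cudaMalloc((void**)&master, n * sizeof(float));
    cudaMalloc((void**)&params, n * sizeof(__nv_bfloat16));
    cudaMemset(master, 0, n * sizeof(float)); // in practice: the fp32 master weights streamed from the state file
    int block_size = 512; // the real code pins this to adamw_update's block size so the per-thread RNG lines up
    int num_blocks = (int)((n + block_size - 1) / block_size);
    params_from_master<<<num_blocks, block_size>>>(params, master, n);
    cudaDeviceSynchronize();
    printf("re-materialized %zu params from the master copy\n", n);
    cudaFree(master);
    cudaFree(params);
    return 0;
}

The real kernel instead calls stochastic_rounding with a per-tensor seed, which is why the block size and thread indexing have to match adamw_update exactly.
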
74 changes: 48 additions & 26 deletions train_gpt2.cu
@@ -311,7 +311,8 @@ typedef struct {
float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps
float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
int use_master_weights; // keep master weights copy in float for optim update? 0|1
unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights
int use_master_weights; // keep master weights copy in float for optim update? 0|1
bool init_state; // set to true if master weights need to be initialized
int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward)
int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2
@@ -438,7 +439,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
fcloseCheck(model_file);
}

void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool resuming=false) {

if (PRECISION_MODE == PRECISION_FP16) {
// TODO for later perhaps, would require us dynamically converting the
@@ -461,16 +462,20 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
exit(EXIT_FAILURE);
}
if (PRECISION_MODE == PRECISION_BF16 && version != 5) {
fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path);
fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n");
exit(EXIT_FAILURE);
}
if (PRECISION_MODE == PRECISION_FP32 && version != 3) {
fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path);
fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n");
fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n");
exit(EXIT_FAILURE);

// check if the precision mode matches the model (don't care if restoring from master weights!)
if (!resuming || !model->use_master_weights) {
if (PRECISION_MODE == PRECISION_BF16 && version != 5) {
fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path);
fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n");
exit(EXIT_FAILURE);
}
if (PRECISION_MODE == PRECISION_FP32 && version != 3) {
fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path);
fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n");
fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n");
exit(EXIT_FAILURE);
}
}

// read in hyperparameters
@@ -483,9 +488,11 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {

gpt2_allocate_weights(model);

// read in all the parameters from file and copy them to device
file_to_device(model->params_memory, model_file, model->num_parameters_bytes,
IO_BUF_SIZE, main_stream);
// read in all the parameters from file and copy them to device (if we need them)
// if we are restoring with master weights, ignore these weights and init from master instead
if (!resuming || !model->use_master_weights) {
file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream);
}
fcloseCheck(model_file);

// only return from this function once we are certain the params are ready on the GPU
@@ -1008,7 +1015,8 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
return grad_norm_cpu;
}

void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config) {
void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t,
MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {
// update the model parameters using the AdamW optimizer
// keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs
// so we may not be responsible for the entire parameter tensor
@@ -1028,6 +1036,10 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));
cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));
}

// save RNG state at this point so we can round from master weights identically when restoring from a checkpoint
model->rng_state_last_update = model->rng_state;

// AdamW update
// handle adamw for all the transformer blocks
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
@@ -1064,13 +1076,17 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
cudaCheck(cudaGetLastError());
}

// ok finally call the kernel
adamw_update(param_ptr, master_ptr, grad_ptr,
m_ptr, v_ptr,
shard.size, tensor.size, tensor.size, shard.size, num_layers,
learning_rate,
beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);
cudaCheck(cudaGetLastError());
if (init_from_master_only) {
// when resuming training from a checkpoint with master weights (allows changing precision)
init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
} else {
// ok finally call the kernel to update the weights with AdamW
adamw_update(param_ptr, master_ptr, grad_ptr,
m_ptr, v_ptr,
shard.size, tensor.size, tensor.size, shard.size, num_layers,
learning_rate,
beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);
}

if (multi_gpu_config->zero_stage == 1) {
#if MULTI_GPU
@@ -1189,6 +1205,7 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader)
state_header[10] = step; // step of the optimization
// model rng state, start at 20 to leave some padding
*((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state
*((unsigned long long*)&state_header[22]) = model->rng_state_last_update; // last gpt2_update
// dataloader state, start at 30 to leave some padding
*((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset
*((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard
Expand Down Expand Up @@ -1225,6 +1242,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
int should_shuffle = state_header[5]; // shuffle state of the dataloader
*step = state_header[10]; // step of the optimization
model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state
model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update
size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index
size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard

@@ -1237,17 +1255,21 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
printf0("Error: Master weights requested, but not present in state file.");
exit(EXIT_FAILURE);
}

model->init_state = false; // we just got the state from file, no need to do first-touch init
assert(model->m_memory != nullptr);
assert(model->v_memory != nullptr);
file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
if(model->use_master_weights) {
assert(model->master_weights != nullptr);
file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
// restore weights from the master weights using the RNG state before last weight update
model->rng_state = model->rng_state_last_update;
gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true);
model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this
}

model->init_state = false; // we just got the state from file, no need to do first-touch init

// revive the DataLoader object and its state
loader->should_shuffle = should_shuffle;
if (should_shuffle == 1) {
@@ -1532,7 +1554,7 @@ int main(int argc, char *argv[]) {
gpt2_init_common(&model);
if (resuming == 1) {
// if `-y 1` was set, then we are resuming from the latest checkpoint
gpt2_build_from_checkpoint(&model, filename_buffer);
gpt2_build_from_checkpoint(&model, filename_buffer, true);
Owner: at this point model.use_master_weights is not yet initialized with use_master_weights, that happens below, and is just the default (true). This variable is used inside gpt2_build_from_checkpoint and this is probably a bug?

} else if (ends_with_bin(load_filename)) {
// otherwise, if this is a .bin file, we assume it's a model, let's init from it
gpt2_build_from_checkpoint(&model, load_filename);
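
A note on how the pieces above fit together (summarizing the diff, with a toy illustration after): gpt2_update now snapshots model->rng_state into rng_state_last_update just before the weight update, and save_state writes that snapshot into the state header (slot 22). On resume with master weights enabled, gpt2_build_from_checkpoint skips both the precision check and the low-precision weight load, load_state streams the fp32 master copy from the state file, temporarily rewinds rng_state to rng_state_last_update, and calls gpt2_update with init_from_master_only=true so the working-precision weights are re-rounded with exactly the same random bits as the last real update, which is what allows restoring into a build of a different precision. The Owner's comment above flags one caveat: model.use_master_weights is still at its default when gpt2_build_from_checkpoint(&model, filename_buffer, true) is called, so the skip logic may not yet reflect the user's setting.

The CPU-only sketch below shows why replaying the saved RNG state reproduces bit-identical low-precision weights. It is an assumption, not llm.c code: xorshift32 and stochastic_round_bf16 are stand-ins for the repo's RNG and stochastic_rounding helper, and bf16 storage is mimicked by keeping the top 16 bits of a float.

#include <cstdint>
#include <cstdio>
#include <cstring>

// toy RNG standing in for the model's rng_state
static uint32_t xorshift32(uint32_t* s) { *s ^= *s << 13; *s ^= *s >> 17; *s ^= *s << 5; return *s; }

// stochastically round an fp32 value to a bf16-like 16-bit pattern:
// add uniform noise over the dropped mantissa bits, then truncate
static uint16_t stochastic_round_bf16(float x, uint32_t* rng) {
    uint32_t bits; std::memcpy(&bits, &x, sizeof(bits));
    uint32_t noise = xorshift32(rng) & 0xFFFFu;
    return (uint16_t)((bits + noise) >> 16);
}

int main() {
    const float master[4] = {0.1f, 1.5f, -2.25f, 3.14159f}; // pretend fp32 master weights
    uint16_t first[4], replay[4];

    uint32_t rng = 1337u; // rng_state at the time of the last weight update
    for (int i = 0; i < 4; i++) { first[i] = stochastic_round_bf16(master[i], &rng); }

    rng = 1337u;          // on restore: rewind to the saved rng_state_last_update
    for (int i = 0; i < 4; i++) { replay[i] = stochastic_round_bf16(master[i], &rng); }

    printf("replayed rounding identical: %s\n",
           std::memcmp(first, replay, sizeof(first)) == 0 ? "yes" : "no");
    return 0;
}
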