
Commit

simplify a little bit
ademeure committed Jul 20, 2024
1 parent 12c423e · commit 270ca24
Showing 2 changed files with 5 additions and 17 deletions.
llmc/adamw.cuh (14 changes: 2 additions & 12 deletions)
@@ -66,23 +66,13 @@ __global__ void params_from_master_kernel(Tp* params_memory, float* master_param
                                           ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
     size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= num_parameters) { return; }
-
-    // adjust for layer offset
-    params_memory += blockIdx.y * w_stride;
+    params_memory += blockIdx.y * w_stride; // adjust for layer offset
     master_params_memory += blockIdx.y * s_stride;
 
     Tp rounded_param;
     float param = master_params_memory[idx];
     stochastic_rounding(param, &rounded_param, seed);
-
-    if (check_identical) {
-        // check if the rounded parameter is identical to the master parameter (debugging only)
-        if (params_memory[idx] != rounded_param) {
-            printf("Mismatch restoring master weights at index %llu (of %llu): %.20f != %.20f\n",
-                   idx, num_parameters, (float)params_memory[idx], (float)rounded_param);
-            assert(false);
-        }
-    }
+    assert(!check_identical || params_memory[idx] == rounded_param); // for debugging only (needs non-master params to be loaded as well)
     params_memory[idx] = rounded_param;
 }
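The single assert added above is logically equivalent to the removed nested check, minus the printf diagnostic: "if check_identical, then the stored value must equal the re-rounded one" is the implication written as !A || B. A minimal standalone sketch of that equivalence (hypothetical names, not code from the repository):

    #include <cassert>

    // old shape: branch on the flag, then fail on a mismatch
    void check_old(bool check_identical, float stored, float rounded) {
        if (check_identical) {
            if (stored != rounded) {
                assert(false); // the original kernel also printed a diagnostic here
            }
        }
    }

    // new shape: the same condition as one implication, assert(!A || B)
    void check_new(bool check_identical, float stored, float rounded) {
        assert(!check_identical || stored == rounded);
    }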

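For context on the unchanged stochastic_rounding(param, &rounded_param, seed) call: it converts the float master weight into the lower-precision parameter type Tp, rounding up or down at random with probabilities chosen so the result is unbiased in expectation. A conceptual device-side sketch of the fp32-to-bf16 case (an illustration of the general add-and-truncate technique, not the repository's stochastic_rounding; NaN/Inf handling omitted):

    #include <cuda_bf16.h>

    __device__ __nv_bfloat16 stochastic_round_bf16_sketch(float x, unsigned int rnd) {
        unsigned int bits = __float_as_uint(x);
        // add a random offset below the bf16 truncation point, then keep the top 16 bits:
        // the magnitude rounds up with probability equal to the discarded fraction
        bits += (rnd & 0xFFFFu);
        return __ushort_as_bfloat16((unsigned short)(bits >> 16));
    }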
train_gpt2.cu (8 changes: 3 additions & 5 deletions)
@@ -1088,11 +1088,9 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
     }
 
     if (init_from_master_only) {
-        // this is only run when resuming training from a checkpoint with master weights
-        // it allows us to restart training with a different precision amongst other things
-        assert(master_ptr != NULL);
-        params_from_master(param_ptr, master_ptr,
-                           shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
+        // when resuming training from a checkpoint with master weights
+        init_from_master(param_ptr, master_ptr,
+                         shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
     } else {
         // ok finally call the kernel to update the weights with AdamW
         adamw_update(param_ptr, master_ptr, grad_ptr,
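The updated call passes (param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream), which maps onto the params_from_master_kernel shown in the adamw.cuh hunk above: one grid dimension over the parameters of a shard, and blockIdx.y over layers, with w_stride/s_stride selecting each layer's full and sharded slices. A hypothetical launcher with roughly that shape (the block size, signature order, and check_identical default are assumptions, not the actual init_from_master in llmc/adamw.cuh):

    #include <cstddef>
    #include <cuda_runtime.h>

    template <typename Tp>
    void init_from_master_sketch(Tp* params_memory, float* master_params_memory, size_t num_parameters,
                                 ptrdiff_t w_stride, ptrdiff_t s_stride, int num_layers,
                                 unsigned int seed, cudaStream_t stream) {
        const int block_size = 512;  // assumed launch configuration
        const int grid_x = (int)((num_parameters + block_size - 1) / block_size);
        dim3 grid(grid_x, num_layers);  // blockIdx.y indexes the layer, as used by the kernel
        // params_from_master_kernel as defined in llmc/adamw.cuh above
        params_from_master_kernel<Tp><<<grid, block_size, 0, stream>>>(
            params_memory, master_params_memory, num_parameters,
            w_stride, s_stride, seed, /*check_identical=*/false);  // assumed default
    }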
