From 270ca240f19cd47367e4b79cbb8add47d480bc1b Mon Sep 17 00:00:00 2001
From: ademeure
Date: Sat, 20 Jul 2024 16:31:27 +0000
Subject: [PATCH] simplify a little bit

---
 llmc/adamw.cuh | 14 ++------------
 train_gpt2.cu  |  8 +++-----
 2 files changed, 5 insertions(+), 17 deletions(-)

diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh
index 409c38033..6d64438ea 100644
--- a/llmc/adamw.cuh
+++ b/llmc/adamw.cuh
@@ -66,23 +66,13 @@ __global__ void params_from_master_kernel(Tp* params_memory, float* master_param
                                           ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
     size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= num_parameters) { return; }
-
-    // adjust for layer offset
-    params_memory += blockIdx.y * w_stride;
+    params_memory += blockIdx.y * w_stride; // adjust for layer offset
     master_params_memory += blockIdx.y * s_stride;
     Tp rounded_param;
     float param = master_params_memory[idx];
     stochastic_rounding(param, &rounded_param, seed);
-
-    if (check_identical) {
-        // check if the rounded parameter is identical to the master parameter (debugging only)
-        if (params_memory[idx] != rounded_param) {
-            printf("Mismatch restoring master weights at index %llu (of %llu): %.20f != %.20f\n",
-                   idx, num_parameters, (float)params_memory[idx], (float)rounded_param);
-            assert(false);
-        }
-    }
+    assert(!check_identical || params_memory[idx] == rounded_param); // for debugging only (needs non-master params to be loaded as well)
     params_memory[idx] = rounded_param;
 }
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 945601e46..b1592c55b 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -1088,11 +1088,9 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
         }
         if (init_from_master_only) {
-            // this is only run when resuming training from a checkpoint with master weights
-            // it allows us to restart training with a different precision amongst other things
-            assert(master_ptr != NULL);
-            params_from_master(param_ptr, master_ptr,
-                               shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
+            // when resuming training from a checkpoint with master weights
+            params_from_master(param_ptr, master_ptr,
+                               shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
         } else {
             // ok finally call the kernel to update the weights with AdamW
             adamw_update(param_ptr, master_ptr, grad_ptr,
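
For reference, a minimal sketch of how the simplified kernel reads once this patch is applied, reconstructed from the hunk above; it assumes the template<typename Tp> declaration and the stochastic_rounding() helper already present in llmc/adamw.cuh, and the full parameter list that the hunk header truncates:

template <typename Tp>
__global__ void params_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
                                          ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_parameters) { return; }
    params_memory += blockIdx.y * w_stride;         // adjust for layer offset
    master_params_memory += blockIdx.y * s_stride;
    // re-round the float32 master copy back to the low-precision parameter type
    Tp rounded_param;
    float param = master_params_memory[idx];
    stochastic_rounding(param, &rounded_param, seed);
    // debugging only: needs the non-master params to be loaded as well
    assert(!check_identical || params_memory[idx] == rounded_param);
    params_memory[idx] = rounded_param;
}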