diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh
index 6d64438ea..e15e81a72 100644
--- a/llmc/adamw.cuh
+++ b/llmc/adamw.cuh
@@ -62,18 +62,13 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
 }
 
 template <typename Tp>
-__global__ void params_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
                                           ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
     size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= num_parameters) { return; }
     params_memory += blockIdx.y * w_stride; // adjust for layer offset
     master_params_memory += blockIdx.y * s_stride;
-
-    Tp rounded_param;
-    float param = master_params_memory[idx];
-    stochastic_rounding(param, &rounded_param, seed);
-    assert(!check_identical || params_memory[idx] == rounded_param); // for debugging only (needs non-master params to be loaded as well)
-    params_memory[idx] = rounded_param;
+    stochastic_rounding(master_params_memory[idx], &params_memory[idx], seed);
 }
 
 template <typename Tp, typename Tg>
@@ -93,11 +88,11 @@ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memo
 }
 
 template <typename Tp>
-void params_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
                         ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream, bool check_identical=false) {
     int block_size = 512; // must match block size of adamw_update so that RNG also matches
     int num_blocks = CEIL_DIV(num_parameters, block_size);
-    params_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
+    init_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
         (params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed, check_identical);
     cudaCheck(cudaGetLastError());
 }
diff --git a/train_gpt2.cu b/train_gpt2.cu
index b1592c55b..f23d6d2f3 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -1020,7 +1020,8 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
     return grad_norm_cpu;
 }
 
-void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {
+void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t,
+                 MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {
     // update the model parameters using the AdamW optimizer
     // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs
     // so we may not be responsible for the entire parameter tensor
@@ -1088,9 +1089,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
         }
 
         if (init_from_master_only) {
-            // when resuming training from a checkpoint with master weights
-            init_from_master(param_ptr, master_ptr,
-                             shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
+            // when resuming training from a checkpoint with master weights (allows changing precision)
+            init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
         } else {
             // ok finally call the kernel to update the weights with AdamW
             adamw_update(param_ptr, master_ptr, grad_ptr,
@@ -1099,7 +1099,6 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
                          learning_rate, beta1, beta2, t, eps, wd, grad_scale, seed,
                          main_stream);
         }
-        cudaCheck(cudaGetLastError());
 
         if (multi_gpu_config->zero_stage == 1) {
 #if MULTI_GPU