Commit 6387e66
simplified further (don't need non-functional error checking...)
ademeure committed Jul 20, 2024
1 parent 270ca24 commit 6387e66
Showing 2 changed files with 8 additions and 14 deletions.
llmc/adamw.cuh: 13 changes (4 additions & 9 deletions)
@@ -62,18 +62,13 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
 }
 
 template <typename Tp>
-__global__ void params_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
                                           ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
     size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= num_parameters) { return; }
     params_memory += blockIdx.y * w_stride; // adjust for layer offset
     master_params_memory += blockIdx.y * s_stride;
-
-    Tp rounded_param;
-    float param = master_params_memory[idx];
-    stochastic_rounding(param, &rounded_param, seed);
-    assert(!check_identical || params_memory[idx] == rounded_param); // for debugging only (needs non-master params to be loaded as well)
-    params_memory[idx] = rounded_param;
+    stochastic_rounding(master_params_memory[idx], &params_memory[idx], seed);
 }
 
 template <typename Tp, typename Tg>
@@ -93,11 +88,11 @@ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memo
 }
 
 template <typename Tp>
-void params_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
                         ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream, bool check_identical=false) {
     int block_size = 512; // must match block size of adamw_update so that RNG also matches
     int num_blocks = CEIL_DIV(num_parameters, block_size);
-    params_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
+    init_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
                              (params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed, check_identical);
     cudaCheck(cudaGetLastError());
 }
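
For reference, the simplified kernel now maps each thread to one FP32 master weight and writes the stochastically rounded low-precision value directly, with the rounding seed and the 512-thread block size shared with adamw_update so the RNG sequence matches. Below is a minimal, self-contained sketch of that pattern for bf16 parameters: the stochastic_round_bf16 helper and the per-element hash are stand-ins assumed for illustration (the repo's actual stochastic_rounding helper derives its random bits differently), and the layer/slice striding of the real kernel is omitted.

#include <cuda_bf16.h>

// Hypothetical stand-in for llmc's stochastic_rounding: convert FP32 to bf16,
// rounding up or down at random with probability equal to the discarded
// mantissa fraction (Inf/NaN edge cases are ignored for brevity).
__device__ __nv_bfloat16 stochastic_round_bf16(float x, unsigned int rand16) {
    unsigned int bits = __float_as_uint(x);
    bits += rand16 & 0xFFFFu;                            // random carry into the kept bits
    return __float2bfloat16_rz(__uint_as_float(bits));   // then truncate the low 16 bits
}

// Sketch of the init-from-master pattern: one thread per parameter.
__global__ void init_from_master_sketch(__nv_bfloat16* params, const float* master,
                                        size_t num_parameters, unsigned int seed) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_parameters) { return; }
    // cheap per-element hash of (seed, idx) to decorrelate rounding decisions
    unsigned int h = seed ^ (unsigned int)(idx * 2654435761u);
    h ^= h >> 16; h *= 2246822519u; h ^= h >> 13;
    params[idx] = stochastic_round_bf16(master[idx], h);
}
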
train_gpt2.cu: 9 changes (4 additions & 5 deletions)
@@ -1020,7 +1020,8 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
     return grad_norm_cpu;
 }
 
-void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {
+void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t,
+                 MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {
     // update the model parameters using the AdamW optimizer
     // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs
     // so we may not be responsible for the entire parameter tensor
@@ -1088,9 +1089,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
         }
 
         if (init_from_master_only) {
-            // when resuming training from a checkpoint with master weights
-            init_from_master(param_ptr, master_ptr,
-                             shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
+            // when resuming training from a checkpoint with master weights (allows changing precision)
+            init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream);
         } else {
             // ok finally call the kernel to update the weights with AdamW
             adamw_update(param_ptr, master_ptr, grad_ptr,
@@ -1099,7 +1099,6 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
                          learning_rate,
                          beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);
         }
-        cudaCheck(cudaGetLastError());
 
         if (multi_gpu_config->zero_stage == 1) {
 #if MULTI_GPU
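
Usage note: the init_from_master_only flag added to gpt2_update is what a resume path flips to re-materialize the low-precision parameters from the FP32 master weights after loading a checkpoint, presumably what "allows changing precision" refers to. A hedged sketch of such a call site follows; only gpt2_update's signature is taken from this diff, while the zeroed hyperparameters and the load_master_weights helper are hypothetical.

// Hypothetical resume flow; the real checkpoint handling in train_gpt2.cu differs.
void resume_from_master_weights_sketch(GPT2 *model, MultiGpuConfig *multi_gpu_config,
                                       const char *checkpoint_path, int step) {
    // 1) assume the FP32 master weights (and optimizer state) have already been
    //    restored from checkpoint_path into the model's buffers, e.g.:
    // load_master_weights(model, checkpoint_path);      // hypothetical helper

    // 2) regenerate the low-precision parameter tensors from the master copy;
    //    the optimizer hyperparameters are presumably unused on this path
    //    (an assumption), so placeholder zeros are passed here.
    gpt2_update(model, 0.0f /*lr*/, 0.0f /*beta1*/, 0.0f /*beta2*/, 0.0f /*eps*/,
                0.0f /*weight_decay*/, 0.0f /*grad_scale*/, step, multi_gpu_config,
                /*init_from_master_only=*/true);

    // 3) training then continues with ordinary gpt2_update calls each step.
}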