diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh
index 0ce728ee1..70ce30290 100644
--- a/llmc/cuda_utils.cuh
+++ b/llmc/cuda_utils.cuh
@@ -136,7 +136,7 @@ __global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// need to try grid stride looping for more perf later
if (idx < n) {
-        dst[idx + stride_dst * blockIdx.y] = cast_value<Td, Ts>
-            (src[idx + stride_src * blockIdx.y]);
+        dst[idx + stride_dst * blockIdx.y] = cast_value<Td, Ts>(src[idx + stride_src * blockIdx.y]);
}
}
@@ -260,4 +260,21 @@ __device__ __forceinline__ void stochastic_rounding(float in, float *out, unsign
*out = in; // dummy function for when floatX is float (FP32 mode)
}
+// Add two (potentially low-precision) vectors of size `n` together using stochastic rounding
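+// Each thread loads one 128-bit packet from both vectors (with streaming loads to avoid cache pollution),
+// adds the elements in fp32, and writes the result back with stochastic rounding, seeded per packet via `seed + idx`.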
+template<class T>
+__global__ void vector_add(T* dst, const T* src, size_t n, unsigned seed) {
+    using t128 = Packed128<T>;
+ assert(n % t128::size == 0);
+ ptrdiff_t idx = ((ptrdiff_t)blockIdx.x * blockDim.x + threadIdx.x) * t128::size;
+ if (idx < n) {
+ t128 src_v = load128cs(src + idx);
+ t128 dst_v = load128cs(dst + idx);
+ for(int k = 0; k < t128::size; ++k) {
+ float sum = (float)dst_v[k] + (float)src_v[k];
+ stochastic_rounding(sum, &dst_v[k], seed + idx);
+ }
+ store128cs(dst + idx, dst_v);
+ }
+}
+
#endif
\ No newline at end of file
diff --git a/llmc/zero.cuh b/llmc/zero.cuh
index e6c5b6e7c..d84c3a210 100644
--- a/llmc/zero.cuh
+++ b/llmc/zero.cuh
@@ -506,27 +506,40 @@ ShardInfo multi_gpu_get_shard_offset(size_t elements, const MultiGpuConfig* conf
}
}
-// Block NCCL stream until computations on compute_stream are done, then aggregate multiple pointers in an NCCL group.
+void nccl_wait_on_compute(MultiGpuConfig* config, cudaStream_t compute_stream) {
+ // mark an event on the compute stream, and immediately wait on this in the nccl stream
+ // this means that the nccl stream won't start executing before all compute kernels that
+ // have been submitted before this point have finished.
+ // by using an event instead of cudaSyncStream, we avoid having to synchronize the host, and
+ // can enqueue new work to the GPU right away.
+#ifdef MULTI_GPU
+ cudaCheck(cudaEventRecord(config->compute_nccl_sync, compute_stream));
+ cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
+#endif
+}
+
+void compute_wait_on_nccl(MultiGpuConfig* config, cudaStream_t compute_stream) {
+ // mark an event on the nccl stream, and immediately wait on this in the compute stream
+#ifdef MULTI_GPU
+ cudaCheck(cudaEventRecord(config->compute_nccl_sync, config->nccl_stream));
+ cudaCheck(cudaStreamWaitEvent(compute_stream, config->compute_nccl_sync));
+#endif
+}
+
+// Aggregate multiple pointers in an NCCL group.
// This can work either as an all-reduce (i.e., no ZeRo), or a reduce-scatter (ZeRO 1).
// The awkward `(&pointers)[N]` syntax ensures we are capturing the parameters as sized arrays, so that it becomes impossible
// to call this function if pointers and pointers_sizes do not match.
template<int N>
void multi_gpu_async_reduce_gradient(
floatX* const (&pointers)[N], const size_t (&pointers_sizes)[N],
- MultiGpuConfig* config, cudaStream_t compute_stream) {
+ MultiGpuConfig* config) {
if (config->num_processes == 1) {
return; // no multi-GPU, just exit.
}
#ifdef MULTI_GPU
NVTX_RANGE_FN();
- // mark an event on the compute stream, and immediately wait on this in the nccl stream
- // this means that the nccl stream won't start executing before all compute kernels that
- // have been submitted before this point have finished.
- // by using an event instead of cudaSyncStream, we avoid having to synchronize the host, and
- // can enqueue new work to the GPU right away.
- cudaCheck(cudaEventRecord(config->compute_nccl_sync, compute_stream));
- cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
for (int i = 0; i < N; ++i) {
if(config->zero_stage == 0) {
@@ -536,7 +549,7 @@ void multi_gpu_async_reduce_gradient(
ncclFloatX, ncclAvg,
config->nccl_comm, config->nccl_stream
));
- } else if(config->zero_stage == 1) {
+ } else if(config->zero_stage == 1 || config->zero_stage == 2) {
assert(pointers_sizes[i] % config->num_processes == 0);
size_t shard_size = pointers_sizes[i] / config->num_processes;
ptrdiff_t shard_offset = (ptrdiff_t)shard_size * config->process_rank;
@@ -562,18 +575,18 @@ void set_zero_configs(MultiGpuConfig* config, int zero_stage, size_t total_param
if (zero_stage == 0) {
printf0("| Zero Optimization is disabled |\n");
}
- else if (zero_stage == 1) {
+ else if (zero_stage == 1 || zero_stage == 2) {
if (total_parameters % config->num_processes != 0) {
printf0("| Zero Optimization is disabled, Can't equally partition parameters |\n");
config->zero_stage = 0;
}
else {
- config->zero_stage = 1;
+ config->zero_stage = zero_stage;
config->shard_num_parameters = total_parameters / config->num_processes;
}
}
else{
- printf0("| Disabling Zero Optimization, Zero Stage2 and Stage3 are not yet supported |\n");
+ printf0("| Disabling Zero Optimization, Zero Stage3 is not yet supported |\n");
config->zero_stage = 0;
}
}
@@ -593,5 +606,20 @@ float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) {
#endif
}
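+
+// ZeRO-2 helper: after the per-layer reduce-scatter, add this rank's freshly reduced gradient shard
+// (sitting in `src`, the double-buffered staging area) into the persistent per-rank shard buffers `dst`
+// with stochastic rounding, then clear the staging buffers for reuse. Everything runs on the NCCL stream,
+// so it is ordered after the reduce-scatter that produced `src`.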
+template<int N>
+void zero2_accumulate_grad(floatX* const (&dst)[N], floatX* const (&src)[N], const size_t (&nelem)[N], int layer, unsigned seed, MultiGpuConfig* config) {
+#ifdef MULTI_GPU
+ cudaStream_t stream = config->nccl_stream;
+ for(int i = 0; i < N; ++i) {
+        size_t n = nelem[i] / config->num_processes;
+        // one thread per 128-bit packet; a block size of 512 is assumed here
+        vector_add<<<CEIL_DIV(n, 512 * x128::size), 512, 0, stream>>>(dst[i] + layer * n,
+                                                                      src[i] + config->process_rank * n,
+                                                                      n, seed + i);
+ cudaCheck(cudaGetLastError());
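+        // the backward kernels accumulate (+=) into the staging buffer, and with double buffering it is
+        // reused two layers later, so clear it here on the nccl stream before compute touches it again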
+ cudaCheck(cudaMemsetAsync(src[i], 0, nelem[i] * sizeof(floatX), stream));
+ }
+#endif
+}
+
#endif
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 8f110911a..916d7eea3 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -293,7 +293,11 @@ typedef struct {
size_t num_parameters_bytes;
// gradients of the weights
ParameterTensors grads;
+ size_t grads_bytes;
+ ParameterTensors grad_shards; // ZeRO-2 gradient shards
+ size_t grad_shards_bytes;
void* grads_memory;
+ void* grad_shards_memory;
// buffers for the AdamW optimizer
float* m_memory;
float* v_memory;
@@ -335,6 +339,7 @@ void gpt2_init_common(GPT2 *model) {
model->params_memory = NULL;
// memory lazily initialized in backward()
model->grads_memory = NULL;
+ model->grad_shards_memory = NULL;
model->workload_indices = NULL; // on cpu, for encoder_backward
model->bucket_info = NULL; // on cpu, for encoder_backward
// memory lazily initialized in update()
@@ -364,9 +369,44 @@ void gpt2_allocate_weights(GPT2 *model) {
}
void gpt2_allocate_state(GPT2 *model, int B, int T) {
- printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
assert(model->grads_memory == nullptr);
- model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof);
+
+ if(multi_gpu_config.zero_stage == 2) {
+ // Allocate parameter buffers for the current layers active "wave" of computation
+ size_t param_elements[NUM_PARAMETER_TENSORS];
+ size_t param_sizeof[NUM_PARAMETER_TENSORS];
+ GPT2Config wave_config = model->config;
+ // to prevent having to wait for comms to complete, we need to double-buffer gradients, so we need to
+ // allocate as if we had a two-layer network
+ wave_config.num_layers = 2;
+ fill_in_parameter_sizes(param_elements, param_sizeof, wave_config);
+ size_t alloc_bytes = 0;
+ for(int i = 0; i < NUM_PARAMETER_TENSORS; ++i) {
+ alloc_bytes += param_sizeof[i] * param_elements[i];
+ }
+ printf0("allocating %d MiB for ZeRO-2 active gradients\n",
+ (int) round(alloc_bytes / (1024 * 1024)));
+ model->grads_memory = malloc_and_point_parameters(&model->grads, param_elements, param_sizeof);
+ model->grads_bytes = alloc_bytes;
+
+ // next, allocate memory for the local gradient shards
+ alloc_bytes = 0;
+ fill_in_parameter_sizes(param_elements, param_sizeof, model->config);
+ for(int i = 0; i < NUM_PARAMETER_TENSORS; ++i) {
+ param_elements[i] /= multi_gpu_config.num_processes;
+ alloc_bytes += param_sizeof[i] * param_elements[i];
+ }
+ printf0("allocating %d MiB for ZeRO-2 gradient shards\n",
+ (int) round(alloc_bytes / (1024 * 1024)));
+ model->grad_shards_memory = malloc_and_point_parameters(&model->grad_shards, param_elements, param_sizeof);
+ model->grad_shards_bytes = alloc_bytes;
+ } else {
+ printf0("allocating %d MiB for parameter gradients\n",
+ (int) round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
+ model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements,
+ model->param_sizeof);
+ model->grads_bytes = model->num_parameters * sizeof(floatX);
+ }
// record the current B,T as well
model->batch_size = B;
@@ -774,7 +814,10 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// 1) the losses accumulate += into acts.losses, reset here
// 2) the gradients accumulate += into grads_memory, reset here
cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream));
- cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream));
+ cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->grads_bytes, main_stream));
+ if(model->grad_shards_memory != NULL) {
+ cudaCheck(cudaMemset(model->grad_shards_memory, 0, model->grad_shards_bytes));
+ }
}
// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
@@ -837,18 +880,19 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
floatX* l_fcw = params.fcw + l * 4*C * C;
floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
// get the pointers of the gradients of the weights for this layer
- floatX* dl_ln1w = grads.ln1w + l * C;
- floatX* dl_ln1b = grads.ln1b + l * C;
- floatX* dl_qkvw = grads.qkvw + l * 3*C * C;
- floatX* dl_qkvb = grads.qkvb + l * 3*C;
- floatX* dl_attprojw = grads.attprojw + l * C * C;
- floatX* dl_attprojb = grads.attprojb + l * C;
- floatX* dl_ln2w = grads.ln2w + l * C;
- floatX* dl_ln2b = grads.ln2b + l * C;
- floatX* dl_fcw = grads.fcw + l * 4*C * C;
- floatX* dl_fcb = grads.fcb + l * 4*C;
- floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
- floatX* dl_fcprojb = grads.fcprojb + l * C;
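+        // ZeRO-2 only allocates two layers' worth of gradient buffers (double buffering, see
+        // gpt2_allocate_state) and ping-pongs between them, so index the buffers with l % 2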
+ ptrdiff_t grad_l = multi_gpu_config.zero_stage == 2 ? l % 2 : l;
+ floatX* dl_ln1w = grads.ln1w + grad_l * C;
+ floatX* dl_ln1b = grads.ln1b + grad_l * C;
+ floatX* dl_qkvw = grads.qkvw + grad_l * 3*C * C;
+ floatX* dl_qkvb = grads.qkvb + grad_l * 3*C;
+ floatX* dl_attprojw = grads.attprojw + grad_l * C * C;
+ floatX* dl_attprojb = grads.attprojb + grad_l * C;
+ floatX* dl_ln2w = grads.ln2w + grad_l * C;
+ floatX* dl_ln2b = grads.ln2b + grad_l * C;
+ floatX* dl_fcw = grads.fcw + grad_l * 4*C * C;
+ floatX* dl_fcb = grads.fcb + grad_l * 4*C;
+ floatX* dl_fcprojw = grads.fcprojw + grad_l * C * 4*C;
+ floatX* dl_fcprojb = grads.fcprojb + grad_l * C;
// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
float* l_ln1_mean = acts.ln1_mean + l * B * T;
@@ -902,7 +946,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream);
// Accumulate gradients from this layer in a background stream.
- if(last_step) {
+ if(last_step || multi_gpu_config.zero_stage == 2) {
floatX* const pointers[] = {
dl_ln1w, dl_ln1b,
dl_qkvw, dl_qkvb,
@@ -919,14 +963,46 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
4 * C * C, 4 * C,
C * 4 * C, C
};
- multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
+ if(multi_gpu_config.zero_stage == 2) {
+ // wait for previous layer's transfer to be completed, so we don't get any race conditions
+ // where main stream is writing to a buffer that nccl stream is reading
+ // this will result in a structure as follows:
+ // main stream: calculate grads of layer n-1 and write to buffer 1 || NCCL stream: waiting for buffer 1 data
+ // main stream: calculate grads of layer n-2 and write to buffer 2 || NCCL stream: transmitting buffer 1 data
+ // main stream: potentially waiting (sync) while nccl stream finishes transmitting buffer 1 data
+ // main stream: calculate grads of layer n-3 and write to buffer 1 || NCCL stream: transmitting buffer 2 data
+ // ...
+ compute_wait_on_nccl(&multi_gpu_config, main_stream);
+ }
+ nccl_wait_on_compute(&multi_gpu_config, main_stream);
+ multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config);
+
+ if(multi_gpu_config.zero_stage == 2) {
+                // and add this rank's shard of the freshly reduced gradients into the local shard buffers.
+                // why can't we reduce-scatter into the shard buffers directly?
+                // because ncclReduceScatter overwrites its destination, and the shard buffers must keep
+                // the values accumulated over previous micro-steps (gradient accumulation)
+
+ ParameterTensors g = model->grad_shards;
+ floatX* const dst_ptr[] = {
+ g.ln1w, g.ln1b,
+ g.qkvw, g.qkvb,
+ g.attprojw, g.attprojb,
+ g.ln2w, g.ln2b,
+ g.fcw, g.fcb,
+ g.fcprojw, g.fcprojb
+ };
+
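+                // draw a fresh seed per layer so the stochastic-rounding noise differs across layers and steps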
+ unsigned int seed = random_u32(&model->rng_state);
+ zero2_accumulate_grad(dst_ptr, pointers, nelem, l, seed, &multi_gpu_config);
+ }
}
}
encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);
// Aggregate all gradients that are not part of the transformer blocks
- if(last_step) {
+ if(last_step || multi_gpu_config.zero_stage == 2) {
// reduce all the losses within the current GPU (across all microsteps)
global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream);
// reduce loss across GPUs to a single, final float across all microsteps and GPUs
@@ -937,7 +1013,18 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// reduce the gradients for non-transformer block parameters
floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
const size_t nelem[] = {Vp * C, T * C, C, C};
- multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
+ nccl_wait_on_compute(&multi_gpu_config, main_stream);
+ multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config);
+ if(multi_gpu_config.zero_stage == 2) {
+ ParameterTensors g = model->grad_shards;
+ floatX* const dst_ptr[] = {
+ g.wte, g.wpe,
+ g.lnfw, g.lnfb,
+ };
+ unsigned int seed = random_u32(&model->rng_state);
+ // runs on nccl_stream, so automatically sequential w.r.t. the previous all-reduce
+ zero2_accumulate_grad(dst_ptr, pointers, nelem, 0, seed, &multi_gpu_config);
+ }
}
cudaCheck(cudaDeviceSynchronize());
@@ -997,12 +1084,21 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
// further sum the (partial) squared norm across all GPUs
ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream));
#endif
- } else {
+ } else if(multi_gpu_config->zero_stage == 0) {
// in regular DDP, backward has averaged the gradients across all GPUs
// so each GPU can compute the squared norm over the whole grad vector, with no added comms needed
global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream);
global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream);
+ } else if(multi_gpu_config->zero_stage == 2) {
+#if MULTI_GPU
+ // our gradient shards are contiguous in memory, so taking a (partial) global norm is easy
+ global_norm_squared(grad_norm_squared, (floatX*)model->grad_shards_memory, model->num_parameters / multi_gpu_config->num_processes,
+ 0, 1, max_num_block_sums, true, main_stream);
+ global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream);
+ ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream));
+#endif
}
+
cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost));
float grad_norm_cpu = sqrtf(grad_norm_squared_cpu);
return grad_norm_cpu;
@@ -1051,6 +1147,9 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f;
floatX* param_ptr = (floatX*)model->params_memory + local_offset_full;
floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full;
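+        // with ZeRO-2 the reduced gradients for this rank already live in the contiguous shard buffer,
+        // so point grad_ptr there at the sharded offset instead of into grads_memory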
+ if(multi_gpu_config->zero_stage == 2) {
+ grad_ptr = (floatX*)model->grad_shards_memory + local_offset_partial;
+ }
ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? local_offset_full : local_offset_partial;
float* m_ptr = model->m_memory + opt_state_offset;
@@ -1067,12 +1166,15 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
// ok finally call the kernel
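+        // with ZeRO-2 the gradient shards are stored contiguously per layer, so the per-layer gradient
+        // stride passed to adamw_update is shard.size rather than the full per-layer tensor size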
adamw_update(param_ptr, master_ptr, grad_ptr,
m_ptr, v_ptr,
- shard.size, tensor.size, tensor.size, shard.size, num_layers,
+ shard.size, tensor.size,
+ multi_gpu_config->zero_stage == 2 ? shard.size : tensor.size,
+ shard.size,
+ num_layers,
learning_rate,
beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);
cudaCheck(cudaGetLastError());
- if (multi_gpu_config->zero_stage == 1) {
+ if (multi_gpu_config->zero_stage != 0) {
#if MULTI_GPU
ncclCheck(ncclGroupStart());
for(int l = 0; l < num_layers; ++l) {
@@ -1122,6 +1224,7 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) {
void gpt2_free(GPT2 *model) {
cudaFreeCheck(&model->params_memory);
cudaFreeCheck(&model->grads_memory);
+ cudaFreeCheck(&model->grad_shards_memory);
cudaFreeCheck(&model->m_memory);
cudaFreeCheck(&model->v_memory);
cudaFreeCheck(&model->master_weights);