
move set_zero_configs into zero.cuh #674

Closed · wants to merge 3 commits
107 changes: 76 additions & 31 deletions llmc/zero.cuh
@@ -18,7 +18,7 @@ Utilities for ZeRO sharding
#endif
#endif

-// defines: fcloseCheck, fwriteCheck, scloseCheck, sclosesocketCheckCheck
+// defines: fcloseCheck, fwriteCheck, scloseCheck, sclosesocketCheck
#include "utils.h"

// ----------------------------------------------------------------------------
@@ -78,6 +78,10 @@ typedef struct {
#endif
} MultiGpuConfig;

+// one global variable to hold the multi-GPU configuration for this process
+// inline, so we can include this header multiple times without getting multiple definitions
+inline MultiGpuConfig multi_gpu_config;

#ifdef MULTI_GPU

#ifdef _WIN32
@@ -461,22 +465,22 @@ MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gp
#endif
}

-void multi_gpu_config_free(MultiGpuConfig* multi_gpu_config) {
+void multi_gpu_config_free(MultiGpuConfig* config) {
#ifdef MULTI_GPU
-ncclCheck(ncclCommDestroy(multi_gpu_config->nccl_comm));
-cudaCheck(cudaStreamDestroy(multi_gpu_config->nccl_stream));
-cudaCheck(cudaEventDestroy(multi_gpu_config->compute_nccl_sync));
-cudaCheck(cudaFree(multi_gpu_config->unified_buffer));
+ncclCheck(ncclCommDestroy(config->nccl_comm));
+cudaCheck(cudaStreamDestroy(config->nccl_stream));
+cudaCheck(cudaEventDestroy(config->compute_nccl_sync));
+cudaCheck(cudaFree(config->unified_buffer));
#ifdef USE_MPI
mpiCheck(MPI_Finalize());
#endif
#endif
}

-void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) {
+void multi_gpu_barrier(const MultiGpuConfig* config) {
#ifdef MULTI_GPU
-if (multi_gpu_config->num_processes > 1) {
-ncclCheck(ncclAllReduce(multi_gpu_config->unified_buffer, multi_gpu_config->unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream));
+if (config->num_processes > 1) {
+ncclCheck(ncclAllReduce(config->unified_buffer, config->unified_buffer, sizeof(float), ncclFloat, ncclSum, config->nccl_comm, config->nccl_stream));
}
cudaCheck(cudaDeviceSynchronize());
#endif
@@ -489,14 +493,14 @@ typedef struct {
} ShardInfo;

// Get info about sharding for a tensor of elements many numbers
-ShardInfo multi_gpu_get_shard_offset(size_t elements, const MultiGpuConfig* multi_gpu_config, int shard_at_stage) {
-const int nproc = multi_gpu_config->num_processes;
-if(multi_gpu_config->zero_stage >= shard_at_stage) {
+ShardInfo multi_gpu_get_shard_offset(size_t elements, const MultiGpuConfig* config, int shard_at_stage) {
+const int nproc = config->num_processes;
+if(config->zero_stage >= shard_at_stage) {
if (elements % nproc != 0) {
fprintf(stderr, "Number of elements %zu must be a multiple of the number of processes %d\n", elements, nproc);
exit(EXIT_FAILURE);
}
-return {(ptrdiff_t) (multi_gpu_config->process_rank * (elements / nproc)), elements / nproc};
+return {(ptrdiff_t) (config->process_rank * (elements / nproc)), elements / nproc};
} else {
return {0, elements};
}
@@ -508,9 +512,9 @@ ShardInfo multi_gpu_get_shard_offset(size_t elements, const MultiGpuConfig* mult
// to call this function if pointers and pointers_sizes do not match.
template<int N>
void multi_gpu_async_reduce_gradient(
-floatX* const (&pointers)[N], const size_t (&pointers_sizes)[N],
-MultiGpuConfig* multi_gpu_config, cudaStream_t compute_stream) {
-if (multi_gpu_config->num_processes == 1) {
+floatX* const (&pointers)[N], const size_t (&pointers_sizes)[N],
+MultiGpuConfig* config, cudaStream_t compute_stream) {
+if (config->num_processes == 1) {
return; // no multi-GPU, just exit.
}

@@ -521,32 +525,73 @@ void multi_gpu_async_reduce_gradient(
// have been submitted before this point have finished.
// by using an event instead of cudaSyncStream, we avoid having to synchronize the host, and
// can enqueue new work to the GPU right away.
-cudaCheck(cudaEventRecord(multi_gpu_config->compute_nccl_sync, compute_stream));
-cudaCheck(cudaStreamWaitEvent(multi_gpu_config->nccl_stream, multi_gpu_config->compute_nccl_sync));
+cudaCheck(cudaEventRecord(config->compute_nccl_sync, compute_stream));
+cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
for (int i = 0; i < N; ++i) {
-if(multi_gpu_config->zero_stage == 0) {
+if(config->zero_stage == 0) {
ncclCheck(ncclAllReduce(
-pointers[i], pointers[i],
-pointers_sizes[i],
-ncclFloatX, ncclAvg,
-multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream
+pointers[i], pointers[i],
+pointers_sizes[i],
+ncclFloatX, ncclAvg,
+config->nccl_comm, config->nccl_stream
));
-} else if(multi_gpu_config->zero_stage == 1) {
-assert(pointers_sizes[i] % multi_gpu_config->num_processes == 0);
-size_t shard_size = pointers_sizes[i] / multi_gpu_config->num_processes;
-ptrdiff_t shard_offset = (ptrdiff_t)shard_size * multi_gpu_config->process_rank;
+} else if(config->zero_stage == 1) {
+assert(pointers_sizes[i] % config->num_processes == 0);
+size_t shard_size = pointers_sizes[i] / config->num_processes;
+ptrdiff_t shard_offset = (ptrdiff_t)shard_size * config->process_rank;
ncclCheck(ncclReduceScatter(
-pointers[i], pointers[i] + shard_offset,
-shard_size,
-ncclFloatX, ncclAvg,
-multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream
+pointers[i], pointers[i] + shard_offset,
+shard_size,
+ncclFloatX, ncclAvg,
+config->nccl_comm, config->nccl_stream
));
}
}
ncclCheck(ncclGroupEnd());
#endif
}

+// convenience macro that only prints if the rank of process is zero
+#define printf0(...) if (::multi_gpu_config.process_rank == 0) { printf(__VA_ARGS__); }
+
+void set_zero_configs(MultiGpuConfig* config, int zero_stage, size_t total_parameters) {
+config->zero_stage = 0;
+config->shard_num_parameters = total_parameters;
+// Check the Zero Stage and define sharding parameters
+if (zero_stage == 0) {
+printf0("| Zero Optimization is disabled |\n");
+}
+else if (zero_stage == 1) {
+if (total_parameters % config->num_processes != 0) {
+printf0("| Zero Optimization is disabled, Can't equally partition parameters |\n");
+config->zero_stage = 0;
+}
+else {
+config->zero_stage = 1;
+config->shard_num_parameters = total_parameters / config->num_processes;
+}
+}
+else{
+printf0("| Disabling Zero Optimization, Zero Stage2 and Stage3 are not yet supported |\n");
+config->zero_stage = 0;
+}
+}
+
+// Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
+float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) {
+#ifdef MULTI_GPU
+if (config->num_processes == 1) return value;
+
+float* unified_buffer = config->unified_buffer;
+*unified_buffer = value;
+ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, config->nccl_comm, config->nccl_stream));
+cudaCheck(cudaDeviceSynchronize());
+return *unified_buffer;
+#else
+return value;
+#endif
+}

#endif
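Not part of the diff: a minimal usage sketch of the helpers that now live in zero.cuh. It assumes multi_gpu_config has already been filled in by multi_gpu_config_init (whose full signature is truncated above); the parameter count and loss value are hypothetical placeholders.

// Illustration only -- not llm.c code. A translation unit driving the helpers
// moved into zero.cuh; num_params and local_loss are placeholders.
#include <cstdio>
#include "llmc/zero.cuh"

void example_after_init(size_t num_params) {
    // request ZeRO stage 1; falls back to stage 0 if num_params doesn't divide evenly
    set_zero_configs(&multi_gpu_config, /*zero_stage=*/1, num_params);
    // printf0 is now a macro over the inline global, so any includer can use it
    printf0("zero stage %d, %zu parameters per shard\n",
            multi_gpu_config.zero_stage, multi_gpu_config.shard_num_parameters);
    // reduce a per-rank scalar (e.g. a loss) to a mean across all processes
    float local_loss = 0.0f;  // placeholder per-rank value
    float mean_loss = multi_gpu_cpu_float_sum(local_loss, &multi_gpu_config)
                      / multi_gpu_config.num_processes;
    printf0("mean loss across ranks: %f\n", mean_loss);
    multi_gpu_barrier(&multi_gpu_config);  // make sure all ranks reach this point
}

For example, with 4 processes and 124,439,808 parameters, stage 1 gives shard_num_parameters = 31,109,952 per rank; the C++17 inline keyword on the global is what lets every .cu file include this header without multiple-definition errors at link time.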

50 changes: 0 additions & 50 deletions train_gpt2.cu
@@ -73,44 +73,9 @@ char filename_buffer[512];
// global vars containing information about the GPU this process is running on
cudaDeviceProp deviceProp; // fills in common_start()
cudaStream_t main_stream;
-// one global variable to hold the multi-GPU configuration for this process
-MultiGpuConfig multi_gpu_config;
// buffer size to use for device <-> disk io
constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024;

-// convenience function that only prints if the rank of process is zero
-void printf0(const char *format, ...) {
-if (multi_gpu_config.process_rank == 0) {
-va_list args;
-va_start(args, format);
-vprintf(format, args);
-va_end(args);
-}
-}

-void set_zero_configs(MultiGpuConfig* multi_gpu_config, int zero_stage, size_t total_parameters) {
-multi_gpu_config->zero_stage = 0;
-multi_gpu_config->shard_num_parameters = total_parameters;
-// Check the Zero Stage and define sharding parameters
-if (zero_stage == 0) {
-printf0("| Zero Optimization is disabled |\n");
-}
-else if (zero_stage == 1) {
-if (total_parameters % multi_gpu_config->num_processes != 0) {
-printf0("| Zero Optimization is disabled, Can't equally partition parameters |\n");
-multi_gpu_config->zero_stage = 0;
-}
-else {
-multi_gpu_config->zero_stage = 1;
-multi_gpu_config->shard_num_parameters = total_parameters / multi_gpu_config->num_processes;
-}
-}
-else{
-printf0("| Disabling Zero Optimization, Zero Stage2 and Stage3 are not yet supported |\n");
-multi_gpu_config->zero_stage = 0;
-}
-}

// ----------------------------------------------------------------------------
// GPT-2 model definition

@@ -938,21 +903,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
}
}

-// Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
-float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* multi_gpu_config) {
-#ifdef MULTI_GPU
-if (multi_gpu_config->num_processes == 1) return value;
-
-float* unified_buffer = multi_gpu_config->unified_buffer;
-*unified_buffer = value;
-ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream));
-cudaCheck(cudaDeviceSynchronize());
-return *unified_buffer;
-#else
-return value;
-#endif
-}

// Gets the offset of a specific tensor for a specific layer in the GPT2 model
// layer_id is ignored for weights that are not part of a transformer block
ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) {
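A side note on the compute-to-NCCL stream handoff that the comments in multi_gpu_async_reduce_gradient describe: the pattern is plain CUDA and can be sketched in isolation. The kernels below are stand-ins for the backward pass and for the collective, not llm.c code; only the cudaEventRecord/cudaStreamWaitEvent pair mirrors what the diff actually does (error checks omitted for brevity).

// Standalone sketch: work queued on comm_stream only starts after everything
// recorded on compute_stream has finished, and the host never blocks in between.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void produce(float* buf, int n) {           // stand-in for the backward pass
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] = 1.0f;
}
__global__ void consume(float* buf, int n) {           // stand-in for the NCCL collective
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] *= 0.5f;
}

int main() {
    const int n = 1 << 20;
    float* buf;
    cudaMalloc(&buf, n * sizeof(float));
    cudaStream_t compute_stream, comm_stream;
    cudaEvent_t sync;
    cudaStreamCreate(&compute_stream);
    cudaStreamCreate(&comm_stream);
    cudaEventCreateWithFlags(&sync, cudaEventDisableTiming);  // no timing needed, cheaper

    produce<<<(n + 255) / 256, 256, 0, compute_stream>>>(buf, n);
    cudaEventRecord(sync, compute_stream);          // mark "compute done up to here"
    cudaStreamWaitEvent(comm_stream, sync, 0);      // comm stream waits on the GPU, not the host
    consume<<<(n + 255) / 256, 256, 0, comm_stream>>>(buf, n);

    cudaDeviceSynchronize();                        // only now does the host wait
    printf("done\n");
    cudaFree(buf);
    return 0;
}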