From fff90f81241ed932023cb0d34f8f8bb70ef6eaa2 Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Thu, 25 Apr 2024 13:13:03 +0300
Subject: [PATCH] reorder weights according to their precision

---
 train_gpt2.cu | 104 ++++++++++++++++++++++++--------------------------
 train_gpt2.py |  18 +++++----
 2 files changed, 60 insertions(+), 62 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index 76f4cef25..996a75da4 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -1325,25 +1325,29 @@ typedef struct {
 } GPT2Config;

 // the parameters of the model
-#define NUM_PARAMETER_TENSORS 16
+constexpr const int NUM_PARAMETER_TENSORS = 16;
 typedef struct {
+    // matrices in lower precision
     floatX* wte; // (V, C)
     floatX* wpe; // (maxT, C)
-    floatN* ln1w; // (L, C)
-    floatN* ln1b; // (L, C)
     floatX* qkvw; // (L, 3*C, C)
     floatX* qkvb; // (L, 3*C)
     floatX* attprojw; // (L, C, C)
     floatX* attprojb; // (L, C)
-    floatN* ln2w; // (L, C)
-    floatN* ln2b; // (L, C)
     floatX* fcw; // (L, 4*C, C)
     floatX* fcb; // (L, 4*C)
     floatX* fcprojw; // (L, C, 4*C)
     floatX* fcprojb; // (L, C)
+
+    // layernorm parameters in higher precision
+    floatN* ln1w; // (L, C)
+    floatN* ln1b; // (L, C)
+    floatN* ln2w; // (L, C)
+    floatN* ln2b; // (L, C)
     floatN* lnfw; // (C)
     floatN* lnfb; // (C)
 } ParameterTensors;
+static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");

 void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) {
     int V = config.vocab_size;
@@ -1352,18 +1356,18 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
     int L = config.num_layers;
     param_sizes[0] = V * C; // wte
     param_sizes[1] = maxT * C; // wpe
-    param_sizes[2] = L * C; // ln1w
-    param_sizes[3] = L * C; // ln1b
-    param_sizes[4] = L * (3 * C) * C; // qkvw
-    param_sizes[5] = L * (3 * C); // qkvb
-    param_sizes[6] = L * C * C; // attprojw
-    param_sizes[7] = L * C; // attprojb
-    param_sizes[8] = L * C; // ln2w
-    param_sizes[9] = L * C; // ln2b
-    param_sizes[10] = L * (4 * C) * C; // fcw
-    param_sizes[11] = L * (4 * C); // fcb
-    param_sizes[12] = L * C * (4 * C); // fcprojw
-    param_sizes[13] = L * C; // fcprojb
+    param_sizes[2] = L * (3 * C) * C; // qkvw
+    param_sizes[3] = L * (3 * C); // qkvb
+    param_sizes[4] = L * C * C; // attprojw
+    param_sizes[5] = L * C; // attprojb
+    param_sizes[6] = L * (4 * C) * C; // fcw
+    param_sizes[7] = L * (4 * C); // fcb
+    param_sizes[8] = L * C * (4 * C); // fcprojw
+    param_sizes[9] = L * C; // fcprojb
+    param_sizes[10] = L * C; // ln1w
+    param_sizes[11] = L * C; // ln1b
+    param_sizes[12] = L * C; // ln2w
+    param_sizes[13] = L * C; // ln2b
     param_sizes[14] = C; // lnfw
     param_sizes[15] = C; // lnfb

@@ -1372,10 +1376,10 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
         param_sizeof[i] = sizeof(floatX);
     }
-    param_sizeof[2] = sizeof(floatN); // ln1w
-    param_sizeof[3] = sizeof(floatN); // ln1b
-    param_sizeof[8] = sizeof(floatN); // ln2w
-    param_sizeof[9] = sizeof(floatN); // ln2b
+    param_sizeof[10] = sizeof(floatN); // ln1w
+    param_sizeof[11] = sizeof(floatN); // ln1b
+    param_sizeof[12] = sizeof(floatN); // ln2w
+    param_sizeof[13] = sizeof(floatN); // ln2b
     param_sizeof[14] = sizeof(floatN); // lnfw
     param_sizeof[15] = sizeof(floatN); // lnfb
 }
@@ -1398,9 +1402,12 @@ float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_eleme
     }
     // assign all the tensors their place in the array
     floatX** ptrs[] = {
-        &params->wte, &params->wpe, (floatX**)&params->ln1w, (floatX**)&params->ln1b, &params->qkvw, &params->qkvb,
-        &params->attprojw, &params->attprojb, (floatX**)&params->ln2w, (floatX**)&params->ln2b, &params->fcw, &params->fcb,
-        &params->fcprojw, &params->fcprojb, (floatX**)&params->lnfw, (floatX**)&params->lnfb
+        &params->wte, &params->wpe, &params->qkvw, &params->qkvb,
+        &params->attprojw, &params->attprojb, &params->fcw, &params->fcb,
+        &params->fcprojw, &params->fcprojb,
+        (floatX**)&params->ln1w, (floatX**)&params->ln1b,
+        (floatX**)&params->ln2w, (floatX**)&params->ln2b,
+        (floatX**)&params->lnfw, (floatX**)&params->lnfb
     };
     char* params_memory_iterator = (char*)params_memory;
     for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
@@ -1794,11 +1801,9 @@ void gpt2_backward(GPT2 *model) {
         model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof, 1);
         printf("allocated %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
         // we're going to be clever for the activations backward pass. we don't need to exactly
-        // mirror the forward pass acrtivations and we will save memory.
+        // mirror the forward pass activations and we will save memory.
         size_t bw_act_sizes[NUM_ACTIVATION_TENSORS];
-        GPT2Config cfg = model->config;
-        cfg.num_layers = 1; // copy the configuration but override number of layers to 1
-        fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, cfg);
+        fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, model->config);
         // count up and allocate the space
         model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
         model->num_grad_acts = 0;
@@ -1922,38 +1927,29 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
     float beta2_correction = 1.0f - powf(beta2, t);

     // Do adam per set of parameters
-    // We need to know the parameter types (float or floatX) to process consecutive chunks
-    // TODO - optimise this to require fewer kernel launches and/or independent via CUDA streams
     char* params_mem = (char*)model->params_memory;
     char* grads_mem = (char*)model->grads_memory;
     size_t num_elements = model->param_elements[0];
     size_t last_sizeof = model->param_sizeof[0];
     size_t current_element = 0;
+    // adam update for the floatX weights
+
+    unsigned int seed = random_u32(&model->rng_state); // seed for stochastic rounding
+    size_t n_floatx = (floatX*) model->params.ln1w - model->params.wte;
+    int num_blocks = CEIL_DIV(n_floatx, block_size);
+    adamw_kernel3<<<num_blocks, block_size>>>(model->params.wte, model->grads.wte,
+                                              model->m_memory, model->v_memory, n_floatx,
+                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps,
+                                              weight_decay, seed);
+    cudaCheck(cudaGetLastError());

-    for (size_t i = 1; i <= NUM_PARAMETER_TENSORS; i++) {
-        if (i == NUM_PARAMETER_TENSORS || model->param_sizeof[i] != last_sizeof) {
-            unsigned int seed = random_u32(&model->rng_state); // seed for stochastic rounding
-            int num_blocks = CEIL_DIV(num_elements, block_size);
-
-            if (last_sizeof == sizeof(floatX)) {
-                adamw_kernel3<<<num_blocks, block_size>>>((floatX*)params_mem, (floatX*)grads_mem,
-                    &model->m_memory[current_element], &model->v_memory[current_element], num_elements,
-                    learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
-            } else {
-                adamw_kernel3<<<num_blocks, block_size>>>((float*)params_mem, (float*)grads_mem,
-                    &model->m_memory[current_element], &model->v_memory[current_element], num_elements,
-                    learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
-            }
-            params_mem += num_elements * last_sizeof;
-            grads_mem += num_elements * last_sizeof;
-            current_element += num_elements;
-            num_elements = 0;
-        }
-        if (i != NUM_PARAMETER_TENSORS) {
-            num_elements += model->param_elements[i];
-            last_sizeof = model->param_sizeof[i];
-        }
-    }
+    // adam update for the floatN weights
+    size_t n_floatn = model->num_parameters - n_floatx;
+    num_blocks = CEIL_DIV(n_floatn, block_size);
+    adamw_kernel3<<<num_blocks, block_size>>>(model->params.ln1w, model->grads.ln1w,
+                                              model->m_memory + n_floatx, model->v_memory + n_floatx, n_floatn,
+                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps,
+                                              weight_decay, seed);
     cudaCheck(cudaGetLastError());
 }
diff --git a/train_gpt2.py b/train_gpt2.py
index d52b25c3d..542b585f7 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -220,10 +220,6 @@ def write_fp32(tensor, file):
 def write_tensors(model_tensors, L, file):
     write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
     write_fp32(model_tensors["transformer.wpe.weight"], file) # (T, C)
-    for i in range(L): # (L, C)
-        write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
-    for i in range(L): # (L, C)
-        write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
     for i in range(L): # (L, 3C, C)
         write_fp32(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
     for i in range(L): # (L, 3C)
@@ -232,10 +228,6 @@ def write_tensors(model_tensors, L, file):
         write_fp32(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
     for i in range(L): # (L, C)
         write_fp32(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
-    for i in range(L): # (L, C)
-        write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
-    for i in range(L): # (L, C)
-        write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
     for i in range(L): # (L, 4C, C)
         write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
     for i in range(L): # (L, 4C)
@@ -244,6 +236,16 @@ def write_tensors(model_tensors, L, file):
         write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
     for i in range(L): # (L, C)
         write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)
+
+    for i in range(L): # (L, C)
+        write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
+    for i in range(L): # (L, C)
+        write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
+    for i in range(L): # (L, C)
+        write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
+    for i in range(L): # (L, C)
+        write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
+
     write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
     write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )
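
Note (reviewer sketch, not part of the patch): the point of the reordering is that the ten floatX tensors now occupy one contiguous region at the front of the parameter buffer and the six floatN layernorm tensors one contiguous region at the back, so gpt2_update can run exactly two adamw_kernel3 launches instead of walking all 16 tensors and switching kernels whenever the element type changes. The pointer difference (floatX*)model->params.ln1w - model->params.wte counts the floatX elements precisely because malloc_and_point_parameters lays the tensors out back-to-back in struct order. The host-side check below is a hypothetical harness, not code from this patch: the main() function and the GPT-2 124M shapes are illustrative, and it assumes the GPT2Config field names (max_seq_len, vocab_size, num_layers, channels) used elsewhere in train_gpt2.cu. It asserts the grouping invariant that the two-launch update relies on.

    // hypothetical sanity check: the first 10 tensors must all be floatX and
    // the last 6 all floatN, otherwise the two-kernel update would mix types
    #include <assert.h>
    #include <stdio.h>

    int main(void) {
        GPT2Config config = {};      // GPT-2 124M shapes (assumed field names)
        config.max_seq_len = 1024;
        config.vocab_size = 50257;
        config.num_layers = 12;
        config.channels = 768;

        size_t param_elements[NUM_PARAMETER_TENSORS];
        size_t param_sizeof[NUM_PARAMETER_TENSORS];
        fill_in_parameter_sizes(param_elements, param_sizeof, config);

        size_t n_floatx = 0, n_floatn = 0;
        for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
            if (i < 10) {            // wte .. fcprojb: the low-precision group
                assert(param_sizeof[i] == sizeof(floatX));
                n_floatx += param_elements[i];
            } else {                 // ln1w .. lnfb: the high-precision group
                assert(param_sizeof[i] == sizeof(floatN));
                n_floatn += param_elements[i];
            }
        }
        // after malloc_and_point_parameters, n_floatx equals the pointer
        // difference (floatX*)params.ln1w - params.wte used in gpt2_update,
        // and n_floatn equals model->num_parameters - n_floatx
        printf("floatX elements: %zu, floatN elements: %zu\n", n_floatx, n_floatn);
        return 0;
    }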