Skip to content

Commit

Permalink
reorder weights according to their precision
Browse files Browse the repository at this point in the history
  • Loading branch information
ngc92 committed Apr 25, 2024
1 parent 716a2ed commit fff90f8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 62 deletions.
104 changes: 50 additions & 54 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1325,25 +1325,29 @@ typedef struct {
} GPT2Config;

// the parameters of the model
#define NUM_PARAMETER_TENSORS 16
constexpr const int NUM_PARAMETER_TENSORS = 16;
typedef struct {
// matrices in lower precision
floatX* wte; // (V, C)
floatX* wpe; // (maxT, C)
floatN* ln1w; // (L, C)
floatN* ln1b; // (L, C)
floatX* qkvw; // (L, 3*C, C)
floatX* qkvb; // (L, 3*C)
floatX* attprojw; // (L, C, C)
floatX* attprojb; // (L, C)
floatN* ln2w; // (L, C)
floatN* ln2b; // (L, C)
floatX* fcw; // (L, 4*C, C)
floatX* fcb; // (L, 4*C)
floatX* fcprojw; // (L, C, 4*C)
floatX* fcprojb; // (L, C)

// layernorm parameters in higher precision
floatN* ln1w; // (L, C)
floatN* ln1b; // (L, C)
floatN* ln2w; // (L, C)
floatN* ln2b; // (L, C)
floatN* lnfw; // (C)
floatN* lnfb; // (C)
} ParameterTensors;
static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");

void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) {
int V = config.vocab_size;
Expand All @@ -1352,18 +1356,18 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
int L = config.num_layers;
param_sizes[0] = V * C; // wte
param_sizes[1] = maxT * C; // wpe
param_sizes[2] = L * C; // ln1w
param_sizes[3] = L * C; // ln1b
param_sizes[4] = L * (3 * C) * C; // qkvw
param_sizes[5] = L * (3 * C); // qkvb
param_sizes[6] = L * C * C; // attprojw
param_sizes[7] = L * C; // attprojb
param_sizes[8] = L * C; // ln2w
param_sizes[9] = L * C; // ln2b
param_sizes[10] = L * (4 * C) * C; // fcw
param_sizes[11] = L * (4 * C); // fcb
param_sizes[12] = L * C * (4 * C); // fcprojw
param_sizes[13] = L * C; // fcprojb
param_sizes[2] = L * (3 * C) * C; // qkvw
param_sizes[3] = L * (3 * C); // qkvb
param_sizes[4] = L * C * C; // attprojw
param_sizes[5] = L * C; // attprojb
param_sizes[6] = L * (4 * C) * C; // fcw
param_sizes[7] = L * (4 * C); // fcb
param_sizes[8] = L * C * (4 * C); // fcprojw
param_sizes[9] = L * C; // fcprojb
param_sizes[10] = L * C; // ln1w
param_sizes[11] = L * C; // ln1b
param_sizes[12] = L * C; // ln2w
param_sizes[13] = L * C; // ln2b
param_sizes[14] = C; // lnfw
param_sizes[15] = C; // lnfb

Expand All @@ -1372,10 +1376,10 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
param_sizeof[i] = sizeof(floatX);
}
param_sizeof[2] = sizeof(floatN); // ln1w
param_sizeof[3] = sizeof(floatN); // ln1b
param_sizeof[8] = sizeof(floatN); // ln2w
param_sizeof[9] = sizeof(floatN); // ln2b
param_sizeof[10] = sizeof(floatN); // ln1w
param_sizeof[11] = sizeof(floatN); // ln1b
param_sizeof[12] = sizeof(floatN); // ln2w
param_sizeof[13] = sizeof(floatN); // ln2b
param_sizeof[14] = sizeof(floatN); // lnfw
param_sizeof[15] = sizeof(floatN); // lnfb
}
Expand All @@ -1398,9 +1402,12 @@ float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_eleme
}
// assign all the tensors their place in the array
floatX** ptrs[] = {
&params->wte, &params->wpe, (floatX**)&params->ln1w, (floatX**)&params->ln1b, &params->qkvw, &params->qkvb,
&params->attprojw, &params->attprojb, (floatX**)&params->ln2w, (floatX**)&params->ln2b, &params->fcw, &params->fcb,
&params->fcprojw, &params->fcprojb, (floatX**)&params->lnfw, (floatX**)&params->lnfb
&params->wte, &params->wpe, &params->qkvw, &params->qkvb,
&params->attprojw, &params->attprojb, &params->fcw, &params->fcb,
&params->fcprojw, &params->fcprojb,
(floatX**)&params->ln1w, (floatX**)&params->ln1b,
(floatX**)&params->ln2w, (floatX**)&params->ln2b,
(floatX**)&params->lnfw, (floatX**)&params->lnfb
};
char* params_memory_iterator = (char*)params_memory;
for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
Expand Down Expand Up @@ -1794,11 +1801,9 @@ void gpt2_backward(GPT2 *model) {
model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof, 1);
printf("allocated %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));
// we're going to be clever for the activations backward pass. we don't need to exactly
// mirror the forward pass acrtivations and we will save memory.
// mirror the forward pass activations and we will save memory.
size_t bw_act_sizes[NUM_ACTIVATION_TENSORS];
GPT2Config cfg = model->config;
cfg.num_layers = 1; // copy the configuration but override number of layers to 1
fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, cfg);
fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, model->config);
// count up and allocate the space
model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
model->num_grad_acts = 0;
Expand Down Expand Up @@ -1922,38 +1927,29 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
float beta2_correction = 1.0f - powf(beta2, t);

// Do adam per set of parameters
// We need to know the parameter types (float or floatX) to process consecutive chunks
// TODO - optimise this to require fewer kernel launches and/or independent via CUDA streams
char* params_mem = (char*)model->params_memory;
char* grads_mem = (char*)model->grads_memory;
size_t num_elements = model->param_elements[0];
size_t last_sizeof = model->param_sizeof[0];
size_t current_element = 0;
// adam update for the floatX weights;

unsigned int seed = random_u32(&model->rng_state); // seed for stochastic rounding
size_t n_floatx = (floatX*) model->params.ln1w - model->params.wte;
int num_blocks = CEIL_DIV(n_floatx, block_size);
adamw_kernel3<<<num_blocks, block_size>>>(model->params.wte, model->grads.wte,
model->m_memory, model->v_memory, n_floatx,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps,
weight_decay, seed);
cudaCheck(cudaGetLastError());

for (size_t i = 1; i <= NUM_PARAMETER_TENSORS; i++) {
if (i == NUM_PARAMETER_TENSORS || model->param_sizeof[i] != last_sizeof) {
unsigned int seed = random_u32(&model->rng_state); // seed for stochastic rounding
int num_blocks = CEIL_DIV(num_elements, block_size);

if (last_sizeof == sizeof(floatX)) {
adamw_kernel3<<<num_blocks, block_size>>>((floatX*)params_mem, (floatX*)grads_mem,
&model->m_memory[current_element], &model->v_memory[current_element], num_elements,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
} else {
adamw_kernel3<<<num_blocks, block_size>>>((float*)params_mem, (float*)grads_mem,
&model->m_memory[current_element], &model->v_memory[current_element], num_elements,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
}
params_mem += num_elements * last_sizeof;
grads_mem += num_elements * last_sizeof;
current_element += num_elements;
num_elements = 0;
}
if (i != NUM_PARAMETER_TENSORS) {
num_elements += model->param_elements[i];
last_sizeof = model->param_sizeof[i];
}
}
// update for the floatn weights
size_t n_floatn = model->num_parameters - n_floatx;
num_blocks = CEIL_DIV(n_floatn, block_size);
adamw_kernel3<<<num_blocks, block_size>>>(model->params.ln1w, model->grads.ln1w,
model->m_memory + n_floatx, model->v_memory + n_floatx, n_floatn,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps,
weight_decay, seed);
cudaCheck(cudaGetLastError());
}

Expand Down
18 changes: 10 additions & 8 deletions train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,6 @@ def write_fp32(tensor, file):
def write_tensors(model_tensors, L, file):
write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
write_fp32(model_tensors["transformer.wpe.weight"], file) # (T, C)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
for i in range(L): # (L, 3C, C)
write_fp32(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
for i in range(L): # (L, 3C)
Expand All @@ -232,10 +228,6 @@ def write_tensors(model_tensors, L, file):
write_fp32(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
for i in range(L): # (L, 4C, C)
write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
for i in range(L): # (L, 4C)
Expand All @@ -244,6 +236,16 @@ def write_tensors(model_tensors, L, file):
write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)

for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)

write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )

Expand Down

0 comments on commit fff90f8

Please sign in to comment.