diff --git a/train_gpt2.cu b/train_gpt2.cu index 1e8db96ea..50b2cb7fa 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -376,7 +376,7 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { size_t param_elements[NUM_PARAMETER_TENSORS]; size_t param_sizeof[NUM_PARAMETER_TENSORS]; GPT2Config wave_config = model->config; - wave_config.num_layers = 1; + wave_config.num_layers = 2; fill_in_parameter_sizes(param_elements, param_sizeof, wave_config); size_t alloc_bytes = 0; for(int i = 0; i < NUM_PARAMETER_TENSORS; ++i) {