diff --git a/test_gpt2.cu b/test_gpt2.cu index e608ce229..512a83723 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -122,7 +122,7 @@ int main(int argc, char *argv[]) { if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); } else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); } - else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); } + else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.act_func_fusion = atoi(argv[i+1]); } } // load additional information that we will use for debugging and error checking