From 2fd6fee971b14bb095e814a17cd8ead0bc192a5f Mon Sep 17 00:00:00 2001 From: ademeure Date: Sun, 7 Jul 2024 12:35:29 +0000 Subject: [PATCH] fix the corner case where T < 256 --- train_gpt2.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index c3ccb0553..920beb658 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1733,7 +1733,7 @@ int main(int argc, char *argv[]) { // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical // (but even if it wasn't fully identical that's probably not the end of the world) // note this is still somewhat wasteful because we don't have a KV cache! - gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, 256) * 256); + gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); // get the V-dimensional vector probs[0, t-1, :] floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)