Commit

fix the corner case where T < 256
ademeure committed Jul 7, 2024
1 parent efa6767 commit 2fd6fee
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train_gpt2.cu
@@ -1733,7 +1733,7 @@ int main(int argc, char *argv[]) {
 // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical
 // (but even if it wasn't fully identical that's probably not the end of the world)
 // note this is still somewhat wasteful because we don't have a KV cache!
-gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, 256) * 256);
+gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));
 // get the V-dimensional vector probs[0, t-1, :]
 floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
 // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)
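
The arithmetic behind the one-line change can be checked in isolation. The sketch below is not part of the commit: `padded_seq_len` and `padded_seq_len_old` are hypothetical helper names introduced only for illustration, and `CEIL_DIV` is assumed to be the usual round-up integer division that the macro's name suggests. It shows that the old expression always rounds the generated length `t` up to a multiple of 256, which gives 256 even when the configured sequence length `T` is smaller, while the new expression caps the rounding block at `T`.

```c
#include <assert.h>

// Assumption: CEIL_DIV in the real source is a round-up integer division like this one.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

// Small integer min, standing in for the min() used in the diff.
static int min_int(int a, int b) { return a < b ? a : b; }

// Hypothetical helper: sequence length requested BEFORE the fix.
// Always a multiple of 256, regardless of T.
static int padded_seq_len_old(int t) {
    assert(t > 0);
    return CEIL_DIV(t, 256) * 256;
}

// Hypothetical helper: sequence length requested AFTER the fix.
// The rounding block is min(T, 256), so for T < 256 the result is exactly T.
static int padded_seq_len(int t, int T) {
    assert(t > 0 && T > 0 && t <= T);
    int block = min_int(T, 256);
    return CEIL_DIV(t, block) * block;
}
```

For example, with `T = 64` and `t = 1` the old expression asks for 256 positions and the new one asks for 64. For `T = 1024` both give 256 until `t` exceeds 256. If `T` is larger than 256 but not a multiple of it, the rounded value can still exceed `T`; that case is outside what this commit addresses.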