diff --git a/torchchat/generate.py b/torchchat/generate.py index ad933687d..f48543499 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -576,6 +576,7 @@ def decode_n_tokens( **sampling_kwargs, ) input_pos += 1 + yield cur_token.clone(), next_prob.clone() break if not encountered_eos: