diff --git a/models/demos/llama3/demo/simple_text_demo.py b/models/demos/llama3/demo/simple_text_demo.py index 2eabfc02cbab..567df355e478 100644 --- a/models/demos/llama3/demo/simple_text_demo.py +++ b/models/demos/llama3/demo/simple_text_demo.py @@ -475,9 +475,7 @@ def test_llama_demo_text( # Get the next token if argmax_on_device: - out_tok = logits - if out_tok.dim() == 1: - out_tok = out_tok.unsqueeze(0) + out_tok = logits.unsqueeze(1) else: # TODO Fix use case with temperature > 0 _, out_tok = sample_host(