diff --git a/models/demos/llama3/demo/simple_text_demo.py b/models/demos/llama3/demo/simple_text_demo.py
index 2eabfc02cbab..567df355e478 100644
--- a/models/demos/llama3/demo/simple_text_demo.py
+++ b/models/demos/llama3/demo/simple_text_demo.py
@@ -475,9 +475,7 @@ def test_llama_demo_text(
 
             # Get the next token
             if argmax_on_device:
-                out_tok = logits
-                if out_tok.dim() == 1:
-                    out_tok = out_tok.unsqueeze(0)
+                out_tok = logits.unsqueeze(1)
             else:
                 # TODO Fix use case with temperature > 0
                 _, out_tok = sample_host(