From 1075e20ba340cb2d205b50f1ed9f7fca8a4c9535 Mon Sep 17 00:00:00 2001 From: mtairum Date: Fri, 17 Jan 2025 17:25:19 +0000 Subject: [PATCH] #0: Minor fix --- models/demos/llama3/demo/simple_text_demo.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/models/demos/llama3/demo/simple_text_demo.py b/models/demos/llama3/demo/simple_text_demo.py index 2eabfc02cbab..567df355e478 100644 --- a/models/demos/llama3/demo/simple_text_demo.py +++ b/models/demos/llama3/demo/simple_text_demo.py @@ -475,9 +475,7 @@ def test_llama_demo_text( # Get the next token if argmax_on_device: - out_tok = logits - if out_tok.dim() == 1: - out_tok = out_tok.unsqueeze(0) + out_tok = logits.unsqueeze(1) else: # TODO Fix use case with temperature > 0 _, out_tok = sample_host(