diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py
index a41645f3394f..be419d8d884c 100644
--- a/models/demos/llama3/tests/test_llama_model.py
+++ b/models/demos/llama3/tests/test_llama_model.py
@@ -320,11 +320,11 @@ def test_llama_model_inference(
                     pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1)
             else:
                 # Greedy decode (temperature = 0) the generated token and save it to print out later
-                tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8)
+                _, tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8)
                 tt_decode_input = embd(tt_out_tok)
                 all_outputs.append(tt_out_tok.squeeze(1).tolist()[0])  # Update generated token to list of TT outputs
                 if run_ref_pt:
-                    pt_out_tok = sample_host(ref_output, None, temperature=0, top_p=0.8)
+                    _, pt_out_tok = sample_host(ref_output, None, temperature=0, top_p=0.8)
                     pt_decode_input = embd(pt_out_tok)
                     all_outputs_ref.append(
                         pt_out_tok.squeeze(1).tolist()[0]
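
For reviewers: the fix unpacks the return value of sample_host, which (judging from the call sites alone) returns a pair whose second element is the sampled token tensor; the old call sites were passing the whole tuple to embd. Below is a minimal sketch of that calling convention. sample_host_sketch and its exact signature are hypothetical stand-ins inferred from this diff, not the tt-metal implementation.

import torch

# Hypothetical stand-in for sample_host. Assumption from the diff's
# call sites: the helper returns (values, tokens), and callers keep
# only the token tensor during greedy decode.
def sample_host_sketch(logits, mesh_device=None, temperature=0.0, top_p=0.8):
    if temperature == 0:
        # Greedy decode: argmax over the vocab; top_p has no effect here.
        tokens = torch.argmax(logits, dim=-1, keepdim=True)
        return None, tokens  # first slot unused on the greedy path
    # Stochastic path: temperature-scaled softmax, then sample
    # (top_p filtering omitted in this sketch).
    probs = torch.softmax(logits / temperature, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1)
    return probs, tokens

# Usage mirroring the fixed call sites: unpack, keep the token only.
logits = torch.randn(1, 128256)  # (batch, vocab_size)
_, out_tok = sample_host_sketch(logits, None, temperature=0, top_p=0.8)
print(out_tok.squeeze(1).tolist()[0])

Without the unpacking, tt_out_tok was bound to the tuple itself, so embd(tt_out_tok) received a tuple rather than the generated token tensor.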