#0: Fix llama test model
mtairum committed Jan 27, 2025
1 parent 38e4802 · commit ec155d5
Showing 1 changed file with 2 additions and 2 deletions: models/demos/llama3/tests/test_llama_model.py
@@ -320,11 +320,11 @@ def test_llama_model_inference(
             pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1)
         else:
             # Greedy decode (temperature = 0) the generated token and save it to print out later
-            tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8)
+            _, tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8)
             tt_decode_input = embd(tt_out_tok)
             all_outputs.append(tt_out_tok.squeeze(1).tolist()[0])  # Update generated token to list of TT outputs
             if run_ref_pt:
-                pt_out_tok = sample_host(ref_output, None, temperature=0, top_p=0.8)
+                _, pt_out_tok = sample_host(ref_output, None, temperature=0, top_p=0.8)
                 pt_decode_input = embd(pt_out_tok)
                 all_outputs_ref.append(
                     pt_out_tok.squeeze(1).tolist()[0]
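The change is mechanical but worth spelling out: the diff suggests sample_host returns a two-element tuple whose second element is the sampled token ids, while the old code bound the entire tuple to tt_out_tok / pt_out_tok, so the downstream embd(...) lookup would receive a tuple instead of a tensor. A minimal sketch of the failure mode and the corrected unpacking follows; the stand-in sample_host below is hypothetical and only models the greedy path, not the repo's actual implementation.

import torch

def sample_host(logits, pcfg=None, temperature=0.0, top_p=0.8):
    # Hypothetical stand-in: with temperature == 0, greedy-decode by taking
    # the argmax over the vocab dim, returning (values, token_ids) the way
    # torch.max does. The real sample_host signature is assumed here.
    assert temperature == 0, "sketch only covers the greedy path"
    values, token_ids = torch.max(logits, dim=-1, keepdim=True)
    return values, token_ids

logits = torch.randn(1, 1, 32000)  # (batch, seqlen, vocab) -- illustrative shapes

# Before the fix: tt_out_tok is bound to the whole tuple, so a later
# embedding lookup like embd(tt_out_tok) fails on a non-tensor input.
tt_out_tok = sample_host(logits, None, temperature=0, top_p=0.8)
print(type(tt_out_tok))  # <class 'tuple'> -- not a tensor

# After the fix: unpack and keep only the token ids.
_, tt_out_tok = sample_host(logits, None, temperature=0, top_p=0.8)
print(tt_out_tok.shape)  # torch.Size([1, 1, 1])

The same unpacking is applied to the reference-PyTorch path (pt_out_tok), keeping the TT and reference decode loops symmetric.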
