From ac0318efdaad8da54f17dfe4fe240e90268f4ffd Mon Sep 17 00:00:00 2001 From: mtairum Date: Thu, 30 Jan 2025 17:43:02 +0000 Subject: [PATCH] #0: Avoid generating cos/sin matrices twice --- models/demos/llama3/tt/llama_model.py | 9 ++------- models/demos/llama3/tt/llama_rope.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index c40cd0315ff..a43e9b38195 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -120,13 +120,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tokens_embd = self.embd(tokens) tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) - tt_rot_mats_prefill = get_prefill_rot_mat( - self.args.head_dim, - self.mesh_device, - seq_len=S, - scale_factor=self.args.rope_scaling_factor, - start_pos=start_pos, - ) + # Slice the rot mats to the prefill seqlen + tt_rot_mats_prefill = [self.rope_setup.cos_matrix[:, :, :S, :], self.rope_setup.sin_matrix[:, :, :S, :]] if page_table is not None: tt_page_table = ttnn.from_torch( diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py index 06406a4eb2d..13449603ad9 100644 --- a/models/demos/llama3/tt/llama_rope.py +++ b/models/demos/llama3/tt/llama_rope.py @@ -55,14 +55,14 @@ def __init__( self.cos_matrix = ttnn.from_torch( cos_matrix, device=device, - layout=ttnn.ROW_MAJOR_LAYOUT, + layout=ttnn.TILE_LAYOUT, dtype=datatype, mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, ) self.sin_matrix = ttnn.from_torch( sin_matrix, device=device, - layout=ttnn.ROW_MAJOR_LAYOUT, + layout=ttnn.TILE_LAYOUT, dtype=datatype, mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, )