Skip to content

Commit

Permalink
#0: Avoid generating cos/sin matrices twice
Browse files (browse the repository at this point in the history)
  • Loading branch information
mtairum committed Jan 30, 2025
1 parent 893979b commit ac0318e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
9 changes: 2 additions & 7 deletions models/demos/llama3/tt/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
tokens_embd = self.embd(tokens)
tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)

- tt_rot_mats_prefill = get_prefill_rot_mat(
- self.args.head_dim,
- self.mesh_device,
- seq_len=S,
- scale_factor=self.args.rope_scaling_factor,
- start_pos=start_pos,
- )
+ # Slice the rot mats to the prefill seqlen
+ tt_rot_mats_prefill = [self.rope_setup.cos_matrix[:, :, :S, :], self.rope_setup.sin_matrix[:, :, :S, :]]

if page_table is not None:
tt_page_table = ttnn.from_torch(
Expand Down
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/llama_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def __init__(
self.cos_matrix = ttnn.from_torch(
cos_matrix,
device=device,
- layout=ttnn.ROW_MAJOR_LAYOUT,
+ layout=ttnn.TILE_LAYOUT,
dtype=datatype,
mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None,
)
self.sin_matrix = ttnn.from_torch(
sin_matrix,
device=device,
- layout=ttnn.ROW_MAJOR_LAYOUT,
+ layout=ttnn.TILE_LAYOUT,
dtype=datatype,
mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None,
)
Expand Down

0 comments on commit ac0318e

Please sign in to comment.