Skip to content

Commit

Permalink
#0: Updated attention wo dense matmul program config to increase perf…
Browse files Browse the repository at this point in the history
… and reduce time to first token
  • Loading branch information
mtairum committed Jan 24, 2025
1 parent 4236e3a commit e89d7e5
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions models/demos/llama3/tt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,18 @@ def find_largest_divisor(n, max_divisor=8):
else ((8, 8) if seq_len >= 1024 else (8, 4)),
)

# TODO Check TTFT (prefill perf) on galaxy with the new config
self.model_config["WO_PREFILL_PROGCFG"] = lambda seq_len: self.matmul_config(
m=min(seq_len, 1024 if self.is_galaxy else 2048),
k=self.dim // self.cluster_shape[0] if self.is_galaxy else self.dim,
n=self.dim // self.cluster_shape[1] if self.is_galaxy else self.dim,
n=self.dim // self.cluster_shape[1]
if self.is_galaxy
else 1024
if self.ccl_topology() == ttnn.Topology.Ring and 1024 % (self.dim / self.num_devices) == 0
else self.dim,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= 1024, # if self.is_galaxy else 2048),
in0_block_w=1 if self.is_galaxy else self.dim // 1024,
fuse_batch=seq_len <= 1024,
)

# Calculate largest number of lm_head_num_rows such that self.dim % (lm_head_num_rows * 8) == 0
Expand Down

0 comments on commit e89d7e5

Please sign in to comment.