From e89d7e58739dd950dd76cd0629d327e4d347e931 Mon Sep 17 00:00:00 2001 From: mtairum <mtairum@tenstorrent.com> Date: Fri, 24 Jan 2025 12:02:30 +0000 Subject: [PATCH] #0: Updated attention wo dense matmul program config to increase perf and reduce time to first token --- models/demos/llama3/tt/model_config.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 0002654966a2..e6e337a96584 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -367,13 +367,18 @@ def find_largest_divisor(n, max_divisor=8): else ((8, 8) if seq_len >= 1024 else (8, 4)), ) + # TODO Check TTFT (prefill perf) on galaxy with the new config self.model_config["WO_PREFILL_PROGCFG"] = lambda seq_len: self.matmul_config( m=min(seq_len, 1024 if self.is_galaxy else 2048), k=self.dim // self.cluster_shape[0] if self.is_galaxy else self.dim, - n=self.dim // self.cluster_shape[1] if self.is_galaxy else self.dim, + n=self.dim // self.cluster_shape[1] + if self.is_galaxy + else 1024 + if self.ccl_topology() == ttnn.Topology.Ring and 1024 % (self.dim / self.num_devices) == 0 + else self.dim, grid_size=(8, 8), - in0_block_w=1, - fuse_batch=seq_len <= 1024, # if self.is_galaxy else 2048), + in0_block_w=1 if self.is_galaxy else self.dim // 1024, + fuse_batch=seq_len <= 1024, ) # Calculate largest number of lm_head_num_rows such that self.dim % (lm_head_num_rows * 8) == 0