From e89d7e58739dd950dd76cd0629d327e4d347e931 Mon Sep 17 00:00:00 2001
From: mtairum <mtairum@tenstorrent.com>
Date: Fri, 24 Jan 2025 12:02:30 +0000
Subject: [PATCH] #0: Updated attention wo dense matmul program config to
 increase perf and reduce time to first token

---
 models/demos/llama3/tt/model_config.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index 0002654966a2..e6e337a96584 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -367,13 +367,18 @@ def find_largest_divisor(n, max_divisor=8):
                 else ((8, 8) if seq_len >= 1024 else (8, 4)),
             )
 
+            # TODO: Check TTFT (time to first token, i.e. prefill perf) on galaxy with the new config
             self.model_config["WO_PREFILL_PROGCFG"] = lambda seq_len: self.matmul_config(
                 m=min(seq_len, 1024 if self.is_galaxy else 2048),
                 k=self.dim // self.cluster_shape[0] if self.is_galaxy else self.dim,
-                n=self.dim // self.cluster_shape[1] if self.is_galaxy else self.dim,
+                n=self.dim // self.cluster_shape[1]
+                if self.is_galaxy
+                else 1024
+                if self.ccl_topology() == ttnn.Topology.Ring and 1024 % (self.dim / self.num_devices) == 0
+                else self.dim,
                 grid_size=(8, 8),
-                in0_block_w=1,
-                fuse_batch=seq_len <= 1024,  # if self.is_galaxy else 2048),
+                in0_block_w=1 if self.is_galaxy else self.dim // 1024,
+                fuse_batch=seq_len <= 1024,
             )
 
             # Calculate largest number of lm_head_num_rows such that self.dim % (lm_head_num_rows * 8) == 0