From af66815b6378ab09b4c6076c25266c83d2e333d4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 6 Sep 2024 05:39:04 +0000 Subject: [PATCH] tweak for llama 405b --- csrc/quantization/machete/generate.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index d1b83268ab711..2f439d079c62e 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -338,7 +338,7 @@ def generate(): # For now we use the same heuristic for all types default_heuristic = [ #### 257+ - ("M > 256 && N <= 6144 && K <= 6144", + ("M > 256 && K <= 16384 && N <= 4096", ScheduleConfig( tile_shape_mn=(128, 128), cluster_shape_mnk=(2, 1, 1), @@ -351,7 +351,13 @@ def generate(): **schedule_common_params # type: ignore )), #### 129-256 - ("M > 128 && N <= 8192 && K <= 8192", + ("M > 128 && K <= 4096 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ("M > 128 && K <= 8192 && N <= 8192", ScheduleConfig( tile_shape_mn=(128, 128), cluster_shape_mnk=(2, 1, 1), @@ -395,7 +401,7 @@ def generate(): cluster_shape_mnk=(1, 1, 1), **schedule_common_params # type: ignore )), - ("M > 32 && K >= 8192 && N >= 12288", + ("M > 32 && K >= 16384 && N >= 12288", ScheduleConfig( tile_shape_mn=(256, 64), cluster_shape_mnk=(2, 1, 1), @@ -408,7 +414,7 @@ def generate(): **schedule_common_params # type: ignore )), #### 17-32 - ("M > 16 && N <= 8192 && K <= 12288", + ("M > 16 && K <= 12288 && N <= 8192", ScheduleConfig( tile_shape_mn=(128, 32), cluster_shape_mnk=(2, 1, 1), @@ -421,7 +427,7 @@ def generate(): **schedule_common_params # type: ignore )), #### 1-16 - ("M <= 16 && N >= 28672", + ("N >= 26624", ScheduleConfig( tile_shape_mn=(256, 16), cluster_shape_mnk=(1, 1, 1),