Skip to content

Commit

Permalink
tweak for llama 405b
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Sep 7, 2024
1 parent a4a5a38 commit af66815
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def generate():
# For now we use the same heuristic for all types
default_heuristic = [
#### 257+
- ("M > 256 && N <= 6144 && K <= 6144",
+ ("M > 256 && K <= 16384 && N <= 4096",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
Expand All @@ -351,7 +351,13 @@ def generate():
**schedule_common_params # type: ignore
)),
#### 129-256
- ("M > 128 && N <= 8192 && K <= 8192",
+ ("M > 128 && K <= 4096 && N <= 4096",
+ ScheduleConfig(
+ tile_shape_mn=(128, 64),
+ cluster_shape_mnk=(2, 1, 1),
+ **schedule_common_params # type: ignore
+ )),
+ ("M > 128 && K <= 8192 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
Expand Down Expand Up @@ -395,7 +401,7 @@ def generate():
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
- ("M > 32 && K >= 8192 && N >= 12288",
+ ("M > 32 && K >= 16384 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 64),
cluster_shape_mnk=(2, 1, 1),
Expand All @@ -408,7 +414,7 @@ def generate():
**schedule_common_params # type: ignore
)),
#### 17-32
- ("M > 16 && N <= 8192 && K <= 12288",
+ ("M > 16 && K <= 12288 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
Expand All @@ -421,7 +427,7 @@ def generate():
**schedule_common_params # type: ignore
)),
#### 1-16
- ("M <= 16 && N >= 28672",
+ ("N >= 26624",
ScheduleConfig(
tile_shape_mn=(256, 16),
cluster_shape_mnk=(1, 1, 1),
Expand Down

0 comments on commit af66815

Please sign in to comment.