Skip to content

Commit

Permalink
[Bugfix] Machete garbage results for some models (large K dim) (vllm-…
Browse files: browse the repository at this point in the history
  • Loading branch information
LucasWilkinson authored Oct 10, 2024
1 parent ce00231 commit a64e7b9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
23 changes: 13 additions & 10 deletions csrc/quantization/machete/machete_mainloop.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -591,24 +591,27 @@ struct MacheteCollectiveMma {
tma_load_b = make_tma_copy_B(
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));

int32_t scale_k =
(ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
int32_t group_size = (ModeHasScales) ? args.group_size : 0;

if constexpr (ModeHasScales) {
tma_load_scale = make_tma_copy_scale(make_logical_tensor(
args.ptr_S, make_shape(M, args.group_size, L), args.dS));
tma_load_scale = make_tma_copy_scale(
make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
}

if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
tma_load_zero = make_tma_copy_zero(make_logical_tensor(
args.ptr_Z, make_shape(M, args.group_size, L), args.dS));
tma_load_zero = make_tma_copy_zero(
make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
}

if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0};
} else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;

if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale ||
KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return {tma_load_a, tma_load_b, tma_load_scale,
tma_load_zero, scale_k, args.group_size};
tma_load_zero, scale_k, group_size};
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in to_underlying_arguments.");
Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/test_machete_gemm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,13 +24,14 @@
(1, 128, 128),
(1, 512, 1024),
(1, 4096, 4096),
(1, 8192, 28672),
(13, 8192, 4096),
(26, 4096, 8192),
(1, 4096, 4096),
(64, 4096, 4096),
(64, 8192, 28672),
(257, 128, 4096),
(257, 4224, 4160),
(257, 4096, 4096),
(64, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
]
Expand Down

0 comments on commit a64e7b9

Please sign in to comment.