hqq unit tests
ElizaWszola committed Oct 25, 2024
1 parent b9106bd commit 5ef4b80
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions tests/kernels/test_marlin_gemm.py
@@ -29,6 +29,7 @@
    marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -453,6 +454,87 @@ def test_awq_marlin_gemm(
    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", [64])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

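    # HQQ here stores weights as 4-bit unsigned integers with float
    # zero-points (note is_zp_float=True in the GEMM call below), so this
    # test drives the uint4 path of the kernel.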
    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

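    # Random 4-bit weight values in [0, 10) plus per-group float scales and
    # zero-points stand in for real HQQ-quantized parameters.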
    b_weight = torch.randint(0,
                             10, (size_n, size_k),
                             dtype=torch.uint8,
                             device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

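    # Pack the raw weights into the GPTQ layout first, then repack them into
    # the Marlin tile format; scales and zero-points get the matching Marlin
    # permutation.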
    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
                                        4).to(dev)
    marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
                                     group_size).to(dev)
    marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
                                      group_size).to(dev)

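    # No act_order in this test, so the g_idx tensors stay empty.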
    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

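    # Run the Marlin GEMM with zero-points enabled; is_zp_float=True selects
    # the HQQ-style float zero-point handling.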
    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_w_q,
        marlin_s,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

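    # Reference result: dequantize group-wise as w = (q - zp) * s, then do a
    # plain matmul against the activations.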
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input,
                              dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)

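To exercise just the new test locally (assuming a CUDA GPU that the Marlin kernels support):

    pytest tests/kernels/test_marlin_gemm.py -k test_hqq_marlin_gemm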