force fp16 type in kernel to reduce wheel size
ElizaWszola committed Oct 25, 2024
1 parent 5340ce8 commit 1da2d97
Showing 2 changed files with 34 additions and 14 deletions.
30 changes: 19 additions & 11 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1774,17 +1774,19 @@ __global__ void Marlin(
           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
           is_zp_float == IS_ZP_FLOAT) { \
-    cudaFuncSetAttribute( \
-        Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
-               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-               HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
-        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
-    Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
-           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
-           HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
-        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
-            A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
-            num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+    if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
+      cudaFuncSetAttribute( \
+          Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
+                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
+                 HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
+          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+      Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
+             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
+             HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
+          <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
+              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
+              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+    } \
   }
 
 typedef struct {
@@ -2273,6 +2275,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               b_q_type.str());
   }
 
+  if (has_zp && is_zp_float) {
+    TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
+                "Computation type must be float16 (half) when using float zero "
+                "points.");
+  }
+
   int pack_factor = 32 / b_q_type.size_bits();
 
   // Verify A
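For reference, a minimal standalone CUDA sketch of the technique this change relies on (assumed names such as `toy_marlin_kernel` and `launch_toy_kernel`; this is not the vLLM macro above): wrapping the launch in `if constexpr` means the float-zero-point path is only ever instantiated for `half`, so no bf16 variant of that kernel is compiled into the wheel.

```cuda
// Minimal sketch, assumed names: an `if constexpr` guard around the launch
// keeps unwanted template specializations out of the binary, because
// statements in a discarded branch are never instantiated.
#include <cuda_fp16.h>
#include <type_traits>

template <typename scalar_t, bool IS_ZP_FLOAT>
__global__ void toy_marlin_kernel(const scalar_t* in, scalar_t* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];  // placeholder body
}

template <typename scalar_t, bool IS_ZP_FLOAT>
void launch_toy_kernel(const scalar_t* in, scalar_t* out, int n,
                       cudaStream_t stream) {
  // Float zero points are only supported for half: for any other scalar_t
  // the branch is discarded at compile time, so no bf16 + float-zero-point
  // kernel is ever generated, which is what shrinks the binary.
  if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) {
    constexpr int threads = 256;
    int blocks = (n + threads - 1) / threads;
    toy_marlin_kernel<scalar_t, IS_ZP_FLOAT>
        <<<blocks, threads, 0, stream>>>(in, out, n);
  }
}
```

The new `TORCH_CHECK` in the second hunk is the runtime counterpart: since the bf16 + float-zero-point specialization no longer exists, the host code rejects that combination up front.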
18 changes: 15 additions & 3 deletions vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -182,11 +182,20 @@ def apply(
                                     GPTQ_MARLIN_MIN_THREAD_N,
                                     GPTQ_MARLIN_MAX_PARALLEL)
 
+        scales = layer.marlin_scales
+        zeros = layer.marlin_zeros
+        orig_type = x.dtype
+
+        if orig_type != torch.float16:
+            x = x.to(torch.float16)
+            scales = scales.to(torch.float16)
+            zeros = zeros.to(torch.float16)
+
         marlin_out = ops.gptq_marlin_gemm(
             x,
             layer.marlin_qweight,
-            layer.marlin_scales,
-            layer.marlin_zeros,
+            scales,
+            zeros,
             layer.g_idx,
             layer.g_idx_sort_indices,
             workspace.scratch,
@@ -203,4 +212,7 @@ def apply(
         if bias is not None:
             marlin_out.add_(bias)
 
-        return marlin_out
+        if orig_type != torch.float16:
+            return marlin_out.to(orig_type)
+        else:
+            return marlin_out
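The Python change above follows a cast-in/cast-out pattern around the now fp16-only kernel: activations, scales, and zero points are converted to float16 before `ops.gptq_marlin_gemm`, and the output is converted back to the caller's original dtype. A hedged C++ (libtorch) sketch of the same idea, with a hypothetical `gemm_fp16_only` wrapper standing in for the quantized GEMM:

```cpp
// Hypothetical wrapper, not vLLM code: cast-in/cast-out around a kernel
// that is only compiled for float16.
#include <torch/torch.h>

torch::Tensor gemm_fp16_only(torch::Tensor x, torch::Tensor w) {
  const auto orig_type = x.scalar_type();
  if (orig_type != torch::kHalf) {
    // The kernel only has a half specialization, so convert inputs first.
    x = x.to(torch::kHalf);
    w = w.to(torch::kHalf);
  }
  // Stand-in for the fp16-only quantized GEMM.
  torch::Tensor out = torch::matmul(x, w);
  // Hand results back in the caller's original dtype.
  return orig_type != torch::kHalf ? out.to(orig_type) : out;
}
```

The extra casts only run for non-fp16 callers (e.g. bf16 models), which trades a small conversion cost for not shipping additional kernel specializations.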
