diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 2ff437cfc94b9..7bd246d73b1e4 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1774,17 +1774,19 @@ __global__ void Marlin(
              has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&             \
              group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS &&     \
              is_zp_float == IS_ZP_FLOAT) {                                     \
-      cudaFuncSetAttribute(                                                    \
-          Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,          \
-                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages,                \
-                 HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>,            \
-          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
-      Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,              \
-             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \
-             HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                                \
-          <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                   \
-              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,        \
-              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);     \
+      if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) {     \
+        cudaFuncSetAttribute(                                                  \
+            Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,        \
+                   THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages,              \
+                   HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>,          \
+            cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);      \
+        Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,            \
+               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,   \
+               HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                              \
+            <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                 \
+                A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,      \
+                num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);   \
+      }                                                                        \
     }
 
 typedef struct {
@@ -2273,6 +2275,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                 b_q_type.str());
   }
 
+  if (has_zp && is_zp_float) {
+    TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
+                "Computation type must be float16 (half) when using float zero "
+                "points.");
+  }
+
   int pack_factor = 32 / b_q_type.size_bits();
 
   // Verify A
diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py
index ecc88a6e1ea87..35c4cb00fb298 100644
--- a/vllm/model_executor/layers/quantization/hqq_marlin.py
+++ b/vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -182,11 +182,20 @@ def apply(
                                     GPTQ_MARLIN_MIN_THREAD_N,
                                     GPTQ_MARLIN_MAX_PARALLEL)
 
+        scales = layer.marlin_scales
+        zeros = layer.marlin_zeros
+        orig_type = x.dtype
+
+        if orig_type != torch.float16:
+            x = x.to(torch.float16)
+            scales = scales.to(torch.float16)
+            zeros = zeros.to(torch.float16)
+
         marlin_out = ops.gptq_marlin_gemm(
             x,
             layer.marlin_qweight,
-            layer.marlin_scales,
-            layer.marlin_zeros,
+            scales,
+            zeros,
             layer.g_idx,
             layer.g_idx_sort_indices,
             workspace.scratch,
@@ -203,4 +212,7 @@ def apply(
         if bias is not None:
             marlin_out.add_(bias)
 
-        return marlin_out
+        if orig_type != torch.float16:
+            return marlin_out.to(orig_type)
+        else:
+            return marlin_out
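
Note (not part of the patch): the `hqq_marlin.py` hunks pair with the `TORCH_CHECK` added in `gptq_marlin.cu` — the float-zero-point kernel is only instantiated for `scalar_t == half`, so the Python layer casts non-fp16 activations down around the kernel call and casts the result back. A minimal sketch of that round-trip, with a hypothetical `run_kernel` standing in for `ops.gptq_marlin_gemm`:

```python
import torch

def fp16_roundtrip(x: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, run_kernel) -> torch.Tensor:
    """Run a half-only kernel on activations of any float dtype."""
    orig_type = x.dtype
    if orig_type != torch.float16:
        # The float-zero-point Marlin path only exists for half precision,
        # so cast activations, scales, and zero points down first.
        x = x.to(torch.float16)
        scales = scales.to(torch.float16)
        zeros = zeros.to(torch.float16)
    out = run_kernel(x, scales, zeros)
    # Hand the caller back its original dtype (e.g. bfloat16).
    return out if orig_type == torch.float16 else out.to(orig_type)
```

Casting back keeps the fallback transparent to callers; the extra conversions cost bandwidth only on the non-fp16 path.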