Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hqq support #21

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_machete.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
size_m=a.shape[0],
size_n=w_ref.shape[1],
size_k=w_ref.shape[0],
is_k_full=True))))
is_k_full=True,
is_zp_float=False))))

# machete
timers.append(
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand All @@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand Down
271 changes: 203 additions & 68 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions csrc/quantization/gptq_marlin/marlin_dtypes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ScalarType<half> {
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;
using FragZPF = Vec<half2, 1>;

static __device__ float inline num2float(const half x) {
return __half2float(x);
Expand Down Expand Up @@ -53,6 +54,7 @@ class ScalarType<nv_bfloat16> {
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
using FragZPF = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
Expand Down
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file

// gptq_marlin repack from GPTQ.
Expand Down
4 changes: 3 additions & 1 deletion tests/kernels/test_marlin_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_gptq_marlin_gemm(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
Expand All @@ -244,6 +244,7 @@ def test_gptq_marlin_gemm(
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

Expand Down Expand Up @@ -431,6 +432,7 @@ def test_awq_marlin_gemm(
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

Expand Down
8 changes: 5 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::ggml_dequantize")
Expand Down Expand Up @@ -595,11 +596,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
has_zp, use_fp32_reduce, is_zp_float)


# fp8 marlin
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def forward(self, input_):
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", bias={hasattr(self, 'bias') and self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def forward(self, input_):
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", bias={hasattr(self, 'bias') and self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
from vllm.model_executor.layers.quantization.neuron_quant import (
Expand All @@ -47,6 +48,7 @@
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq_marlin": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
}
Expand Down
206 changes: 206 additions & 0 deletions vllm/model_executor/layers/quantization/hqq_marlin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from typing import Any, Dict, List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
marlin_make_empty_g_idx, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class HQQMarlinConfig(QuantizationConfig):
"""Config class for HQQ Marlin"""

# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}

def __init__(
self,
weight_bits: int,
group_size: int,
) -> None:
self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size
self.quant_type = self.TYPE_MAP[(weight_bits)]

def __repr__(self) -> str:
return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size})")

@classmethod
def get_name(cls) -> str:
return "hqq_marlin"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
#TODO
return None

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["HQQMarlinMethod"]:
if isinstance(layer, LinearBase):
return HQQMarlinMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []


class HQQMarlinMethod(LinearMethodBase):
"""Linear method for HQQ Marlin.
"""

def __init__(
self,
quant_config: HQQMarlinConfig,
):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
self.output_size_per_partition = sum(output_partition_sizes)

self.input_size_per_partition = input_size_per_partition

weight_loader = extra_weight_attrs.get("weight_loader")

scales_and_zp_size = (input_size_per_partition //
self.quant_config.group_size)

# Quantized weights
qweight = ModelWeightParameter(data=torch.empty(
self.output_size_per_partition,
input_size_per_partition,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

zeros = GroupQuantScaleParameter(data=torch.empty(
self.output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

scales = GroupQuantScaleParameter(data=torch.empty(
self.output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("qweight", qweight)
layer.register_parameter("zeros", zeros)
layer.register_parameter("scales", scales)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
dev = layer.qweight.device
qweight_t = layer.qweight.transpose(1, 0)

gptq_w_q = gptq_pack(qweight_t, 4, self.input_size_per_partition,
self.output_size_per_partition)

sort_indices = torch.empty(0, dtype=torch.int, device=gptq_w_q.device)
marlin_w_q = ops.gptq_marlin_repack(
gptq_w_q,
sort_indices,
self.input_size_per_partition,
self.output_size_per_partition,
4,
).to(dev)
marlin_s = marlin_permute_scales(layer.scales.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
64).to(dev)
marlin_zp = marlin_permute_scales(layer.zeros.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
64).to(dev)

layer.g_idx = marlin_make_empty_g_idx(dev)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)

layer.marlin_qweight = marlin_w_q
layer.marlin_zeros = marlin_zp
layer.marlin_scales = marlin_s

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
workspace = MarlinWorkspace(self.output_size_per_partition,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

marlin_out = ops.gptq_marlin_gemm(
x,
layer.marlin_qweight,
layer.marlin_scales,
layer.marlin_zeros,
layer.g_idx,
layer.g_idx_sort_indices,
workspace.scratch,
scalar_types.uint4,
x.shape[0],
self.output_size_per_partition,
self.input_size_per_partition,
True, # is_k_full
True, # has_zp
False, # use 32-bit reduce
True, # use float zp
)

if bias is not None:
marlin_out.add_(bias)

return marlin_out
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce)
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

if bias is not None:
output.add_(bias) # In-place add
Expand Down Expand Up @@ -340,7 +341,8 @@ def apply_awq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce)
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)

if bias is not None:
output.add_(bias) # In-place add
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def get_quant_config(model_config: ModelConfig,
if model_config.quantization == "gguf":
return quant_cls.from_config({})

if model_config.quantization == "hqq_marlin":
# TODO don't hardcode params
return quant_cls.from_config({"bits": 4, "group_size": 64})

# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
Expand Down
Loading