GPTQ Support [Cont.] (#481)
This PR enables loading GPTQ-quantized models and running weight-only
quantized inference on HPU. For a previous discussion, see #421.
maktukmak authored Jan 30, 2025
1 parent 1710059 commit ce21aad
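
For context, a minimal end-to-end sketch of the workflow this enables, mirroring the new test added below (model, dtype, prompt, and greedy decoding are taken from tests/quantization/test_gptq.py); it assumes an HPU build of vLLM with the vllm_hpu_extension package installed and is not itself part of this commit:

from vllm import LLM, SamplingParams

# Load a GPTQ checkpoint with the new weight-only gptq_hpu method and
# run greedy decoding on HPU (usage sketch, not part of this commit).
llm = LLM(model="TheBloke/Llama-2-7B-Chat-GPTQ",
          dtype="bfloat16",
          quantization="gptq_hpu")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
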
Showing 7 changed files with 43 additions and 6 deletions.
7 changes: 6 additions & 1 deletion .jenkins/test_config.yaml
@@ -68,4 +68,9 @@ stages:
        command: VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_multilora_hpu.py::test_llama_multilora_1x
      - name: test_long_context
        flavor: g2
-       command: VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_long_context_hpu.py::test_quality
+       command: VLLM_SKIP_WARMUP=true pytest -v tests/lora/test_long_context_hpu.py::test_quality
+  - name: tests_int4_quantization
+    steps:
+      - name: test_gptq
+        flavor: g2
+        command: VLLM_SKIP_WARMUP=true pytest -v tests/quantization/test_gptq.py::test_gptq
1 change: 1 addition & 0 deletions README_GAUDI.md
@@ -109,6 +109,7 @@ $ python setup.py develop
| Inference with torch.compile (experimental) | vLLM HPU backend experimentally supports inference with torch.compile. | [vLLM HPU backend execution modes](https://docs.vllm.ai/en/stable/getting_started/gaudi-installation.html#execution-modes) |
| Attention with Linear Biases (ALiBi) | vLLM HPU backend supports models utilizing Attention with Linear Biases (ALiBi) such as mpt-7b. | [vLLM supported models](https://docs.vllm.ai/en/latest/models/supported_models.html) |
| INC quantization | vLLM HPU backend supports FP8 model and KV cache quantization and calibration with Intel Neural Compressor (INC). | [Documentation](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html) |
+| AutoGPTQ quantization | vLLM HPU backend supports inference with models quantized using the AutoGPTQ library. | [Library](https://github.com/AutoGPTQ/AutoGPTQ) |
| LoRA/MultiLoRA support | vLLM HPU backend includes support for LoRA and MultiLoRA on supported models. | [Documentation](https://docs.vllm.ai/en/stable/models/lora.html)<br>[Example](https://docs.vllm.ai/en/stable/getting_started/examples/multilora_inference.html)<br>[vLLM supported models](https://docs.vllm.ai/en/latest/models/supported_models.html) |
| Multi-step scheduling support | vLLM HPU backend includes multi-step scheduling support for host overhead reduction, configurable by standard `--num-scheduler-steps` parameter. | [Feature RFC](https://github.com/vllm-project/vllm/issues/6854) |
| Automatic prefix caching (experimental) | vLLM HPU backend includes automatic prefix caching (APC) support for more efficient prefills, configurable by standard `--enable-prefix-caching` parameter. | [Documentation](https://docs.vllm.ai/en/stable/automatic_prefix_caching/apc.html)<br>[Details](https://docs.vllm.ai/en/stable/automatic_prefix_caching/details.html) |
28 changes: 28 additions & 0 deletions tests/quantization/test_gptq.py
@@ -0,0 +1,28 @@
"""Test model set-up and inference for quantized HF models supported
on the HPU backend using AutoGPTQ.
Validates the configuration and prints results for manual checking.
Run `pytest tests/quantization/test_gptq.py`.
"""

import pytest

from vllm.platforms import current_platform

MODELS = [
    "TheBloke/Llama-2-7B-Chat-GPTQ",
]
DTYPE = ["bfloat16"]


@pytest.mark.skipif(not current_platform.is_hpu(),
                    reason="only supports Intel HPU backend.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_gptq(vllm_runner, model, dtype):
    with vllm_runner(model, dtype=dtype, quantization='gptq_hpu') as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)
    assert output
    print(output)
1 change: 0 additions & 1 deletion vllm/_custom_ops.py
@@ -230,7 +230,6 @@ def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


-# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/linear.py
@@ -30,9 +30,9 @@
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod", "QuarkLinearMethod"
"TPUInt8LinearMethod", "GPTQLinearMethod", "GPTQHPULinearMethod",
"FBGEMMFp8LinearMethod", "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod",
"IPEXGPTQLinearMethod", "HQQMarlinMethod", "QuarkLinearMethod"
]


4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -19,6 +19,7 @@
"gptq_marlin",
"awq_marlin",
"gptq",
"gptq_hpu",
"compressed-tensors",
"bitsandbytes",
"qqq",
@@ -75,6 +76,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
        raise ValueError(f"Invalid quantization method: {quantization}")

    # lazy import to avoid triggering `torch.compile` too early
+   from vllm_hpu_extension.gptq_hpu import GPTQHPUConfig
+
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

    from .aqlm import AQLMConfig
@@ -116,6 +119,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"gptq_hpu": GPTQHPUConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
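
To illustrate the registration above, the new method name resolves through the lookup helper defined in this file; a minimal sketch, assuming vllm_hpu_extension is installed on the host:

from vllm.model_executor.layers.quantization import get_quantization_config

# The registered string now maps to the HPU GPTQ config class provided by
# vllm_hpu_extension (lazy-imported above to avoid early torch.compile).
config_cls = get_quantization_config("gptq_hpu")
print(config_cls.__name__)  # GPTQHPUConfig
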
2 changes: 1 addition & 1 deletion vllm/platforms/hpu.py
@@ -23,7 +23,7 @@ class HpuPlatform(Platform):
    dispatch_key: str = "HPU"
    ray_device_key: str = "HPU"
    device_control_env_var: str = "HABANA_VISIBLE_MODULES"
-   supported_quantization: list[str] = ["fp8", "inc"]
+   supported_quantization: list[str] = ["fp8", "inc", "gptq_hpu"]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
