Recalc scales from user #774

Open
Wants to merge 17 commits into base: habana_main
@@ -3,7 +3,7 @@
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
+from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -82,8 +82,10 @@ def process_weights_after_loading(self, layer) -> None:
 
         # INPUT SCALE
         if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
-            layer.input_scale = Parameter(layer.input_scale.max(),
-                                          requires_grad=False)
+            input_scale = layer.input_scale.max()
+            if is_hpu_gaudi2():
+                input_scale = input_scale * get_hpu_gaudi2_scale_factor()
+            layer.input_scale = Parameter(input_scale, requires_grad=False)
         else:
             layer.input_scale = None

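Reviewer note: a minimal, standalone sketch of the new input-scale branch above, for sanity checking outside vLLM. The real is_hpu_gaudi2() / get_hpu_gaudi2_scale_factor() are imported from vllm_hpu_extension.ops; the stand-ins and the fold_input_scale wrapper below are hypothetical and only assume the factor is the fp8 e4m3fn / e4m3fnuz range ratio that the removed inline code in w8a8_utils.py computed (second file below). Requires a PyTorch build with fp8 dtypes.

```python
import torch
from torch.nn import Parameter


def is_hpu_gaudi2() -> bool:
    # Stand-in for vllm_hpu_extension.ops.is_hpu_gaudi2; assume a non-Gaudi2
    # host so the sketch runs anywhere.
    return False


def get_hpu_gaudi2_scale_factor() -> float:
    # Assumed to match the ratio the removed inline code computed:
    # e4m3fn max (448) over e4m3fnuz max (240), roughly 1.867.
    return torch.finfo(torch.float8_e4m3fn).max / torch.finfo(
        torch.float8_e4m3fnuz).max


def fold_input_scale(input_scale: torch.Tensor) -> Parameter:
    # Mirrors the new logic in process_weights_after_loading: collapse the
    # per-shard input scales to their max and, on Gaudi2, enlarge the scale
    # for the narrower fp8 range the hardware uses.
    scale = input_scale.max()
    if is_hpu_gaudi2():
        scale = scale * get_hpu_gaudi2_scale_factor()
    return Parameter(scale, requires_grad=False)


print(fold_input_scale(torch.tensor([0.010, 0.021, 0.017])))  # 0.021 off Gaudi2
```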
8 changes: 3 additions & 5 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -3,6 +3,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
 from vllm.platforms import current_platform
 
 # Input scaling factors are no longer optional in _scaled_mm starting
@@ -73,16 +74,13 @@
 
     return weight_scale_channel
 
 
 def requantize_with_max_scale(
         weight: torch.Tensor, weight_scale: torch.Tensor,
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requantization.
     max_w_scale = weight_scale.max()
-    if current_platform.is_hpu() and htexp._get_device_type(
-    ) == htexp.synDeviceType.synDeviceGaudi2:
-        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
-                                     torch.finfo(torch.float8_e4m3fnuz).max)
+    if is_hpu_gaudi2():
+        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
     # QKV / MLP is fused in the on disk checkpoint if any of the
     # weight scales are still set to the default since we initialize
     # N weight scales for N shards but we only load 1 weight scale
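Reviewer note: the removed inline expression also pins down what get_hpu_gaudi2_scale_factor() is presumably expected to return on Gaudi2: the ratio between the representable maxima of fp8 e4m3fn and the e4m3fnuz variant. A quick check of that value, assuming a PyTorch build with fp8 dtypes:

```python
import torch

# Ratio the old inline code computed; per this PR it is presumably centralized
# in vllm_hpu_extension.ops.get_hpu_gaudi2_scale_factor().
e4m3fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0
e4m3fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0
print(e4m3fn_max / e4m3fnuz_max)  # ~1.867: the scale grows so e4m3fn-range
                                  # weights still fit the narrower fnuz range
```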