Recalc scales from user #774

Open
wants to merge 17 commits into habana_main
@@ -7,7 +7,7 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
+    apply_fp8_linear, cutlass_fp8_supported, get_gaudi2_scale_factor, is_gaudi2, normalize_e4m3fn_to_e4m3fnuz,
     requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
@@ -82,7 +82,10 @@ def process_weights_after_loading(self, layer) -> None:

         # INPUT SCALE
         if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
-            layer.input_scale = Parameter(layer.input_scale.max(),
+            input_scale = layer.input_scale.max()
+            if is_gaudi2():
+                input_scale = input_scale * get_gaudi2_scale_factor()
+            layer.input_scale = Parameter(input_scale,
                                           requires_grad=False)
         else:
             layer.input_scale = None
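For readers unfamiliar with the fp8 formats involved, here is a standalone sketch of what the new input-scale path above does, assuming only a recent PyTorch with the float8 finfo entries. The names fold_input_scale, GAUDI2_SCALE_FACTOR, and the gaudi2 flag are illustrative stand-ins for the layer.input_scale handling and the is_gaudi2() / get_gaudi2_scale_factor() helpers; they are not part of the PR.

import torch
from torch.nn import Parameter

# Ratio of fp8 ranges: e4m3fn max (448.0) over e4m3fnuz max (240.0), roughly 1.87.
GAUDI2_SCALE_FACTOR = (torch.finfo(torch.float8_e4m3fn).max /
                       torch.finfo(torch.float8_e4m3fnuz).max)

def fold_input_scale(input_scale: torch.Tensor, gaudi2: bool) -> Parameter:
    # Collapse per-shard input scales to a single max, then widen it when the
    # target fp8 format only covers the narrower e4m3fnuz range, which is the
    # assumption behind the Gaudi2-specific factor.
    scale = input_scale.max()
    if gaudi2:
        scale = scale * GAUDI2_SCALE_FACTOR
    return Parameter(scale, requires_grad=False)

# Toy example: three fused shards with different checkpoint input scales.
print(fold_input_scale(torch.tensor([0.010, 0.012, 0.009]), gaudi2=True))
# -> roughly 0.0224, i.e. 0.012 * 448 / 240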
vllm/model_executor/layers/quantization/utils/w8a8_utils.py (9 additions, 4 deletions)
@@ -73,16 +73,21 @@ def convert_to_channelwise(

     return weight_scale_channel
 
+def is_gaudi2():
+    return current_platform.is_hpu() and htexp._get_device_type(
+    ) == htexp.synDeviceType.synDeviceGaudi2
+
+def get_gaudi2_scale_factor():
+    return (torch.finfo(torch.float8_e4m3fn).max /
+            torch.finfo(torch.float8_e4m3fnuz).max)

 def requantize_with_max_scale(
         weight: torch.Tensor, weight_scale: torch.Tensor,
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requantization.
     max_w_scale = weight_scale.max()
-    if current_platform.is_hpu() and htexp._get_device_type(
-    ) == htexp.synDeviceType.synDeviceGaudi2:
-        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
-                                     torch.finfo(torch.float8_e4m3fnuz).max)
+    if is_gaudi2():
+        max_w_scale = max_w_scale * get_gaudi2_scale_factor()
     # QKV / MLP is fused in the on disk checkpoint if any of the
     # weight scales are still set to the default since we initialize
     # N weight scales for N shards but we only load 1 weight scale
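Since only the head of requantize_with_max_scale appears in the hunk above, here is a simplified sketch, under stated assumptions and not the vLLM implementation, of the idea the max scale feeds into: each fused shard is dequantized with its own scale and requantized with the shared max scale, which is the value the Gaudi2 factor gets folded into. requantize_sketch and its gaudi2 flag are illustrative names; the real function also handles shards whose scales were never loaded from the checkpoint, as the comment above notes.

import torch
from typing import List, Tuple

def requantize_sketch(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int],
        gaudi2: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    # One shared scale for all fused shards, widened on Gaudi2 as in the diff.
    max_w_scale = weight_scale.max()
    if gaudi2:
        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
                                     torch.finfo(torch.float8_e4m3fnuz).max)
    start = 0
    for shard_scale, width in zip(weight_scale, logical_widths):
        end = start + width
        # Dequantize the shard with its own scale, requantize with the max.
        dequant = weight[start:end].to(torch.float32) * shard_scale
        weight[start:end] = (dequant / max_w_scale).to(torch.float8_e4m3fn)
        start = end
    return weight, max_w_scale

Using one max scale keeps fused QKV / MLP projections in a single quantization domain; the trade-off is a little precision on the shards whose original scales were smaller.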