
Commit

updated
robertgshaw2-redhat committed Nov 15, 2024
1 parent 4c61b19 commit 86716f8
Showing 1 changed file with 0 additions and 24 deletions.
@@ -102,30 +102,6 @@ def apply_weights(self,

return out

def quantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for quantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the
    # weight scales are still set to the default, since we initialize
    # N weight scales for N shards but only load 1 weight scale
    # from disk in that case. Skip requantization then, since the
    # weight is already quantized with the single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
        torch.float8_e4m3fn).min)
    q_weight = torch.empty_like(weight).to(torch.float8_e4m3fn)

    # If unfused checkpoint, quantize each shard with the single max scale.
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            q_weight[start:end, :], _ = ops.scaled_fp8_quant(
                weight[start:end, :], max_w_scale)
            start = end

    return max_w_scale, q_weight

def check_24(tensor):
    new_tensor = tensor.view(-1, 4)
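For context (not part of this diff): below is a minimal, self-contained sketch of the requantization pattern the removed quantize_with_max_scale implements. The helper fake_scaled_fp8_quant is a hypothetical stand-in for vLLM's ops.scaled_fp8_quant custom op, assumed here only to illustrate collapsing per-shard scales to a single max scale; the shapes and scale values are made up.

import torch

def fake_scaled_fp8_quant(x: torch.Tensor, scale: torch.Tensor):
    # Hypothetical stand-in for ops.scaled_fp8_quant: divide by the scale,
    # clamp to the float8_e4m3fn range, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

# Two logical shards (e.g. gate/up of a fused MLP) with different scales.
logical_widths = [4, 4]
weight = torch.randn(sum(logical_widths), 8)
weight_scale = torch.tensor([0.5, 2.0])

# Collapse the per-shard scales to a single max scale and quantize each
# shard with it, mirroring the loop in the removed function.
max_w_scale = weight_scale.max()
q_weight = torch.empty_like(weight, dtype=torch.float8_e4m3fn)
start = 0
for logical_width in logical_widths:
    end = start + logical_width
    q_weight[start:end, :], _ = fake_scaled_fp8_quant(
        weight[start:end, :], max_w_scale)
    start = end

print(max_w_scale.item(), q_weight.shape)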
