
Commit

Merge pull request #8 from comaniac/patch-1
mgoin authored May 24, 2024
2 parents d69a57f + 2c70d7a · commit c4a9594
Showing 1 changed file with 10 additions and 1 deletion.
auto_fp8/quantize.py: 10 additions & 1 deletion
@@ -65,14 +65,23 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
         torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
     )
     if native_fp8_support:
+        need_reshape = A.dim() == 3
+        if need_reshape:
+            batch_size = A.shape[0]
+            A_input = A.reshape(-1, A.shape[-1])
+        else:
+            batch_size = None
+            A_input = A
         output, _ = torch._scaled_mm(
-            A,
+            A_input,
             B.t(),
             out_dtype=out_dtype,
             scale_a=A_scale,
             scale_b=B_scale,
             bias=bias,
         )
+        if need_reshape:
+            output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
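For context: torch._scaled_mm operates on 2D matrices, while activations in a transformer forward pass are often 3D (batch, seq_len, hidden), which is why this patch flattens the leading dimensions before the FP8 GEMM and restores them afterwards. Below is a minimal sketch of that flatten-and-restore pattern; plain matmul stands in for torch._scaled_mm so it runs without FP8-capable hardware, and the function name batched_2d_gemm and the shapes are illustrative, not part of the repository.

import torch

def batched_2d_gemm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Flatten a 3D activation (batch, seq_len, hidden) into 2D, since the
    # underlying GEMM kernel (torch._scaled_mm in the patch) takes 2D inputs.
    need_reshape = A.dim() == 3
    if need_reshape:
        batch_size = A.shape[0]
        A_input = A.reshape(-1, A.shape[-1])
    else:
        batch_size = None
        A_input = A
    # Plain matmul stands in for the FP8 kernel in this sketch.
    output = A_input @ B.t()
    # Restore the 3D shape the caller expects.
    if need_reshape:
        output = output.reshape(batch_size, output.shape[0] // batch_size, output.shape[1])
    return output

A = torch.randn(4, 16, 32)  # (batch, seq_len, hidden)
B = torch.randn(8, 32)      # (out_features, hidden), as in nn.Linear
print(batched_2d_gemm(A, B).shape)  # torch.Size([4, 16, 8])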

0 comments on commit c4a9594
