From 249902a9c71d445601f5bee41ed79d7c371b3944 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 23 May 2024 13:05:27 -0700
Subject: [PATCH 1/2] Fix fp8_gemm on H100

---
 auto_fp8/quantize.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index 5babe32..d68ca58 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -65,14 +65,23 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
         torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
     )
     if native_fp8_support:
+        need_reshape = A.dim() == 3
+        if need_reshape:
+            batch_size = A.shape[0]
+            A_input = A.reshape(-1, A.shape[-1])
+        else:
+            batch_size = None
+            A_input = A
         output, _ = torch._scaled_mm(
-            A,
+            A_input,
             B.t(),
             out_dtype=out_dtype,
             scale_a=A_scale,
             scale_b=B_scale,
             bias=bias,
         )
+        if need_reshape:
+            output = output.reshape((batch_size, *output.shape))
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,

From 2c70d7a532943d11681a4dd21795d56eebc538d0 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 23 May 2024 13:07:02 -0700
Subject: [PATCH 2/2] Update auto_fp8/quantize.py

---
 auto_fp8/quantize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index d68ca58..ef7ff4d 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -81,7 +81,7 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
             bias=bias,
         )
         if need_reshape:
-            output = output.reshape((batch_size, *output.shape))
+            output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
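
Note: the net effect of the two patches on the batched path, as a minimal standalone
sketch. A plain matmul stands in for torch._scaled_mm (which only takes 2D operands,
hence the flatten), and the tensor shapes below are made up purely for illustration:

import torch

# Hypothetical shapes for illustration only.
batch_size, seq_len, hidden, out_features = 2, 4, 8, 16
A = torch.randn(batch_size, seq_len, hidden)

# Flatten the 3D activation to 2D before the matmul, as PATCH 1/2 does.
A_input = A.reshape(-1, A.shape[-1])                   # (batch * seq, hidden)
output = A_input @ torch.randn(hidden, out_features)   # stand-in for the fp8 GEMM

# PATCH 2/2: restore the batch dimension by dividing the flattened row count.
# The earlier expression reshape((batch_size, *output.shape)) prepends batch_size
# to the full 2D shape, requesting batch_size times too many elements.
output = output.reshape((batch_size, output.shape[0] // batch_size, output.shape[1]))
assert output.shape == (batch_size, seq_len, out_features)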