From 5d0cbafd033b85a1a031bce5909ce08481523efd Mon Sep 17 00:00:00 2001
From: luka
Date: Wed, 4 Dec 2024 22:37:46 +0000
Subject: [PATCH] Fix fusion and tests to use dynamic per-token

Signed-off-by: luka
---
 tests/compile/test_fusion.py | 31 +++++++++++++++++++++----------
 vllm/_custom_ops.py          |  1 -
 vllm/compilation/fusion.py   | 12 +++++++-----
 3 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index fa1765d6ad84a..6dc989f0d634c 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -31,14 +31,22 @@ def __init__(self, hidden_size: int, eps: float, static: bool, *args,
         ]
 
     def forward(self, x):
-        resid = torch.relu(x)
+        resid = torch.sqrt(x)
         y = self.norm[0](x)
 
-        x2 = apply_fp8_linear(y, self.w[0], self.wscale[0], self.scale[0])
+        x2 = apply_fp8_linear(y,
+                              self.w[0],
+                              self.wscale[0],
+                              self.scale[0],
+                              use_per_token_if_dynamic=True)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = apply_fp8_linear(y2, self.w[1], self.wscale[1], self.scale[1])
+        x3 = apply_fp8_linear(y2,
+                              self.w[1],
+                              self.wscale[1],
+                              self.scale[1],
+                              use_per_token_if_dynamic=True)
 
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
@@ -75,12 +83,15 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     model2 = torch.compile(model, backend=backend)
     result2 = model2(x)
 
-    # Check that it gives the same answer, higher tol for dynamic
-    ATOL, RTOL = (1e-3, 1e-3) if static else (2e-2, 2e-2)
-    torch.testing.assert_close(result.to(dtype=torch.float32),
-                               result2.to(dtype=torch.float32),
-                               atol=ATOL,
-                               rtol=RTOL)
+    # Higher tol for dynamic, even higher for bfloat16
+    if static:
+        ATOL, RTOL = (1e-3, 1e-3)
+    elif dtype == torch.float16:
+        ATOL, RTOL = (2e-3, 2e-3)
+    else:
+        ATOL, RTOL = (1e-2, 1e-2)
+
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     # Check substitution worked
     pre_nodes = backend.graph_pre_pass.nodes
@@ -93,7 +104,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     else:
         rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default
         add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default  # noqa: E501
-        fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default
+        fp8_quant = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
 
     # In pre-nodes, fp8 quant should be present and fused kernels should not
     assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index bed3dad57c580..3808fb9a87e56 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -269,7 +269,6 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
-# TODO is this necessary?
 @register_fake("_C::rms_norm_dynamic_per_token_quant")
 def _rms_norm_dynamic_per_token_quant_fake(
         output: torch.Tensor,
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index 907f9ad2c8a7b..823e66867f28a 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -163,7 +163,7 @@ def insert_auto_fn(self, op, kwargs):
 RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
 
 QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default
-QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_scaled_fp8_quant.default
+QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
 
 
 class RMSNormQuantPattern:
@@ -329,7 +329,8 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor,
             at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP,
                                       result=result,
                                       input=at1[1],
-                                      scale=scale)
+                                      scale=scale,
+                                      scale_ub=None)
 
             # result, scale
             return at2[1], at2[2]
@@ -427,7 +428,8 @@ def pattern(result: torch.Tensor, input: torch.Tensor,
             at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP,
                                       result=result,
                                       input=at[1],
-                                      scale=scale)
+                                      scale=scale,
+                                      scale_ub=None)
 
             # result, residual, scale
             return at1[1], at[2], at1[2]
@@ -559,12 +561,12 @@ def __init__(self, config: CompilationConfig.PassConfig):
         FusedAddRMSNormStaticFP8QuantPattern(epsilon).register(
             self.patterns, self.record_match)
 
-        # Fuse rms_norm + dynamic_scaled_fp8_quant into
+        # Fuse rms_norm + dynamic_per_token_scaled_fp8_quant into
         # rms_norm_dynamic_per_token_quant
         RMSNormDynamicFP8QuantPattern(epsilon).register(
             self.patterns, self.record_match)
 
-        # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into
+        # Fuse fused_add_rms_norm + dynamic_per_token_scaled_fp8_quant into
         # rms_norm_dynamic_per_token_quant
         FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register(
             self.patterns, self.record_match)
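For reference, "dynamic per-token" quantization means each token (row) gets its own FP8 scale computed at runtime from that row's maximum magnitude, rather than a single static per-tensor scale. Below is a minimal plain-PyTorch sketch of those semantics, assuming only PyTorch >= 2.1 for the float8_e4m3fn dtype; it is illustrative and is not the fused torch.ops._C.dynamic_per_token_scaled_fp8_quant kernel that the patch switches to.

# Illustrative sketch only: one runtime scale per token (row), dequantize
# with q * scale. Not the vLLM CUDA kernel, just its rough semantics.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn


def dynamic_per_token_fp8_quant_ref(x: torch.Tensor):
    # One scale per row: scale[i] = max(|x[i, :]|) / FP8_MAX,
    # so that x[i, :] / scale[i] fits inside the FP8 range.
    scales = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / FP8_MAX
    scales = scales.clamp(min=1e-12)  # guard against all-zero rows
    q = (x.to(torch.float32) / scales).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return q, scales


# Usage: dequantizing with the per-row scales recovers x with small error,
# which is why the test above can use fairly tight tolerances per dtype.
x = torch.randn(4, 16, dtype=torch.float16)
q, scales = dynamic_per_token_fp8_quant_ref(x)
x_hat = q.to(torch.float32) * scales
print((x.to(torch.float32) - x_hat).abs().max())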