diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 2ad3504bd3bfe..2ce5c730ef62a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -14,12 +14,52 @@ from .backend import TestBackend +# TODO temp +@torch.library.custom_op("_C::rms_norm_dynamic_fp8_quant", + mutates_args=("result", "scale")) +def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, + epsilon: float) -> None: + result_rms = torch.empty_like(input) + torch.ops._C.rms_norm(result_rms, input, weight, epsilon) + torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale) + + +@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, epsilon: float): + return None + + +@torch.library.custom_op("_C::fused_add_rms_norm_dynamic_fp8_quant", + mutates_args=("result", "residual", "scale")) +def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) + torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale) + + +@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, epsilon: float): + return None + + class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, *args, **kwargs): + def __init__(self, hidden_size: int, eps: float, static: bool, *args, + **kwargs): super().__init__(*args, **kwargs) self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + if static: + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + else: + self.scale = [None for _ in range(2)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(2) @@ -29,11 +69,11 @@ def forward(self, x): resid = torch.relu(x) y = self.norm[0](x) - x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1]) + x2 = apply_fp8_linear(y, self.w[0], self.wscale[0], self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3]) + x3 = apply_fp8_linear(y2, self.w[1], self.wscale[1], self.scale[1]) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -48,15 +88,16 @@ def forward(self, x): @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("static", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) # Reshape pass is needed for the fusion pass to work backend = TestBackend(reshape_pass, fusion_pass) - model = TestModel(hidden_size, eps) + model = TestModel(hidden_size, eps, static) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) @@ -74,9 +115,14 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): pre_nodes = backend.graph_pre_pass.nodes post_nodes = backend.graph_post_pass.nodes - rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default - add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + if static: + rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501 + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + else: + rms_quant = torch.ops._C.rms_norm_dynamic_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default # noqa: E501 + fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default # In pre-nodes, fp8 quant should be present and fused kernels should not assert find_auto_fn_maybe(pre_nodes, rms_quant) is None diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index fe18f297c79e4..dfdb86af04a94 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -163,6 +163,7 @@ def insert_auto_fn(self, op, kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default +QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_scaled_fp8_quant.default class RMSNormQuantPattern: @@ -312,6 +313,198 @@ def process(self): fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2]) +class RMSNormDynamicFP8QuantPattern(RMSNormQuantPattern): + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, + result=result, + input=at1[1], + scale=scale) + + # result, scale + return at2[1], at2[2] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized( + torch.ops._C.rms_norm_static_fp8_quant.default, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, scale + return at[1], at[2] + + inputs = [ + empty_fp8(5, 4), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match(self.Match(m))) + + class Match(MultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_OP) + quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP) + + assert len(rms_node.users) == 1 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and scale. + # The auto_fn node returns a tuple of (None, result, scale). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # scale_node_new = at[2] + with self.inserting_after_match(): + kwargs = self.match.kwargs.copy() + + # Scalars cannot be inputs to the pattern + kwargs["epsilon"] = rms_node.kwargs["epsilon"] + del kwargs["result_rms"] # not used in the fused op + + fused_node = self.insert_auto_fn( + torch.ops._C.rms_norm_dynamic_fp8_quant.default, + kwargs=kwargs) + + getitem_nodes = self.insert_getitems(fused_node, (1, 2)) + result_node_new, scale_node_new = getitem_nodes + + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) + + # Finally, fix meta["val"] for de-functionalization. + # See MultiOutputMatch.process for more details. + # Result of fused node is (None, result, scale) + fused_node.meta["val"] = quant_node.meta["val"] + + +class FusedAddRMSNormDynamicFP8QuantPattern(RMSNormQuantPattern): + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, + result=result, + input=at[1], + scale=scale) + + # result, residual, scale + return at1[1], at[2], at1[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized( + torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual, scale + return at[1], at[2], at[3] # TODO confirm signature + + inputs = [ + empty_fp8(5, 4), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match(self.Match(m))) + + class Match(MultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract result, scale, and residual. + # The auto_fn node returns a tuple (None, result, scale, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + # scale_node_new = at[3] + with self.inserting_after_match(): + kwargs = self.match.kwargs.copy() + + # Scalars cannot be inputs to the pattern + kwargs["epsilon"] = rms_node.kwargs["epsilon"] + + fused_node = self.insert_auto_fn( + torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + kwargs=kwargs) + + getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) + result_node_new, residual_node_new, scale_node_new = getitem_ns + + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) + + # Finally, fix meta["val"] for de-functionalization. + # See MultiOutputMatch.process for more details. + rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] + # Result of fused node is (None, result, scale, residual) + fused_node.meta["val"] = (None, quant_tup[1], quant_tup[2], + rms_tup[2]) + + class FusionPass(InductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. @@ -360,6 +553,16 @@ def __init__(self, config: CompilationConfig): FusedAddRMSNormStaticFP8QuantPattern(epsilon).register( self.patterns, self.record_match) + # Fuse rms_norm + dynamic_scaled_fp8_quant into + # rms_norm_dynamic_fp8_quant + RMSNormDynamicFP8QuantPattern(epsilon).register( + self.patterns, self.record_match) + + # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into + # fused_add_rms_norm_dynamic_fp8_quant + FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register( + self.patterns, self.record_match) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear()