diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 7f9c687fa4cd6..0a2b1c2297c1e 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,12 +1,13 @@ import abc import operator from abc import abstractmethod -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Tuple import torch import torch._inductor.pattern_matcher as pm # TODO(luka) use vllm.utils once #10836 landed from compressed_tensors.quantization import FP8_DTYPE +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -30,8 +31,7 @@ def empty_fp32(*args, **kwargs): # Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], - op) -> Optional[torch.fx.Node]: +def find_auto_fn_maybe(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: for node in nodes: if is_func(node, auto_functionalized) and node.args[0] == op: # noqa return node @@ -39,7 +39,7 @@ def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], # Returns the first auto_functionalized node with the given op -def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: +def find_auto_fn(nodes: Iterable[fx.Node], op) -> fx.Node: node = find_auto_fn_maybe(nodes, op) assert node is not None, f"Could not find {op} in nodes {nodes}" return node @@ -47,8 +47,7 @@ def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: # Returns the getitem node that extracts the idx-th element from node # (if it exists) -def find_getitem_maybe(node: torch.fx.Node, - idx: int) -> Optional[torch.fx.Node]: +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: for user in node.users: if is_func(user, operator.getitem) and user.args[1] == idx: return user @@ -56,7 +55,7 @@ def find_getitem_maybe(node: torch.fx.Node, # Returns the getitem node 
class QuantMultiOutputMatch(MultiOutputMatch):
    """
    A multi-output match that carries the unfused quant op and the fused op
    that should replace the matched nodes.

    It adds insert_fused_node, a helper that inserts the fused
    auto-functionalized node, rebinds users of the old outputs and sets the
    meta values on the new nodes.
    """

    def __init__(self, match: pm.Match, quant_op, fused_op):
        """
        :param match: the pattern-matcher match this object wraps
        :param quant_op: the unfused quantization op found in the match
        :param fused_op: the fused op inserted by insert_fused_node
        """
        super().__init__(match)
        self.QUANT_OP = quant_op
        self.FUSED_OP = fused_op

    def insert_fused_node(self, fused_return_mapping: Dict[int,
                                                           Tuple[fx.Node,
                                                                 int]],
                          **kwargs) -> None:
        """
        This utility function inserts an auto-functionalized node for FUSED_OP.
        It also correctly sets its meta value and rebinds the users of the
        unfused nodes to use the fused node instead.

        :param fused_return_mapping: A dictionary, mapping from getitem indices
        of the fused node result to a tuple of the old node and a getitem index.
        An empty mapping is accepted: the fused node is still inserted and its
        meta value is the one-element tuple (None,).
        :param kwargs: kwargs that get directly forwarded to the auto_fn node

        Example:
        If we want to replace this graph:
        _, x1, x2 = auto_fn(op1)
        _, y1, y2 = auto_fn(op2)

        with
        _, x1, y2, x2 = auto_fn(FUSED_OP)

        we would call:
        insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)})

        Note that the 0th element is None for auto-functionalized in-place ops.
        Hence others appear 1-indexed.
        """
        fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)

        # Snapshot the keys once (in mapping order) so the getitem nodes and
        # the loop below are guaranteed to walk the same sequence.
        indices = tuple(fused_return_mapping)
        getitem_nodes = self.insert_getitems(fused_node, indices)

        # Prepare the meta value, use a list so it's mutable.
        # default=0 keeps slot 0 (always None) when the mapping is empty,
        # instead of failing with "max() arg is an empty sequence".
        meta_val = [None] * (max(indices, default=0) + 1)

        # Iterate through elements of the tuple produced by fused_node
        for idx, getitem_node in zip(indices, getitem_nodes):
            old_node, old_idx = fused_return_mapping[idx]

            # Extract the appropriate meta value
            # It is present even if the getitem node does not exist
            old_val = old_node.meta["val"][old_idx]
            meta_val[idx] = old_val

            # If the old value was never used, the old_getitem might not exist
            old_getitem = find_getitem_maybe(old_node, old_idx)
            if old_getitem is None:
                # Nothing to rebind, but still give the new getitem node a
                # meta value so it does not end up without one.
                getitem_node.meta["val"] = old_val
            else:
                # Rebind the users of match getitem nodes to use the new nodes.
                # The old nodes will be removed by DCE at the end of the pass.
                old_getitem.replace_all_uses_with(getitem_node)
                getitem_node.meta["val"] = old_getitem.meta["val"]

        # Fix the meta value on the new fused node
        fused_node.meta["val"] = tuple(meta_val)
- rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] - # Result of fused node is (None, result, residual) - fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2]) + # 0 is always None + fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} + self.insert_fused_node(fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + **kwargs) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -446,7 +485,7 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, extra_check=lambda m: record_match( self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(RMSNormQuantPattern.Match): + class Match(QuantMultiOutputMatch): def process(self): # Find the nodes in the match that we need to rebind @@ -465,28 +504,17 @@ def process(self): # result_node_new = at[1] # scale_node_new = at[2] with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern kwargs = self.match.kwargs.copy() - - # Scalars cannot be inputs to the pattern - kwargs["epsilon"] = rms_node.kwargs["epsilon"] - kwargs["scale_ub"] = None # not used but required - kwargs["residual"] = None # not used but required del kwargs["result_rms"] # not used in the fused op - # TODO simplify - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs) - getitem_nodes = self.insert_getitems(fused_node, (1, 2)) - result_node_new, scale_node_new = getitem_nodes - - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) - - # Finally, fix meta["val"] for de-functionalization. - # See MultiOutputMatch.process for more details. 
- # Result of fused node is (None, result, scale) - fused_node.meta["val"] = quant_node.meta["val"] + fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + residual=None, # not used but required + **kwargs) class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -555,7 +583,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor, extra_check=lambda m: record_match( self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(RMSNormQuantPattern.Match): + class Match(QuantMultiOutputMatch): def process(self): # Find the nodes in the match that we need to rebind @@ -575,28 +603,19 @@ def process(self): # scale_node_new = at[2] # residual_node_new = at[3] with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern kwargs = self.match.kwargs.copy() - # Scalars cannot be inputs to the pattern - kwargs["epsilon"] = rms_node.kwargs["epsilon"] - kwargs["scale_ub"] = None # not used but required - - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs) - getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) - result_node_new, scale_node_new, residual_node_new = getitem_ns - - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) - - # Finally, fix meta["val"] for de-functionalization. - # See MultiOutputMatch.process for more details. 
- rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] - # Result of fused node is (None, result, scale, residual) - fused_node.meta["val"] = (None, quant_tup[1], quant_tup[2], - rms_tup[2]) + fused_return_mapping = { + 1: (quant_node, 1), # result + 2: (quant_node, 2), # scale + 3: (rms_node, 2), # residual + } + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + **kwargs) class FusionPass(VllmInductorPass): @@ -671,7 +690,7 @@ def record_match(self, match: MultiOutputMatch) -> bool: # Return False to prevent automatic replacement. return False - def process_matches(self, graph: torch.fx.Graph): + def process_matches(self, graph: fx.Graph): """ Manually process multi-output matches and replace them with fused nodes. See MultiOutputMatch for more details. @@ -684,7 +703,7 @@ def process_matches(self, graph: torch.fx.Graph): assert all(node not in graph.nodes for match in self.matches for node in match.match.nodes) - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: fx.Graph): self.begin() self.dump_graph(graph, "before_fusion")