
Commit

Improved comments, move utils to MultiOutputMatch
Signed-off-by: luka <luka@neuralmagic.com>
ProExpertProg committed Nov 8, 2024
1 parent fa8e376 commit 414c451
Showing 1 changed file with 78 additions and 47 deletions.
125 changes: 78 additions & 47 deletions vllm/compilation/fusion.py
@@ -1,4 +1,6 @@
import abc
import operator
from abc import abstractmethod
from typing import Callable, Iterable, List, Optional, Tuple

import torch
@@ -62,8 +64,76 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
return ret


class MultiOutputMatch:
pass
class MultiOutputMatch(abc.ABC):

def __init__(self, match: pm.Match):
self.match = match

@property
def nodes(self) -> List[torch.fx.Node]:
return self.match.nodes

@abstractmethod
def process(self, graph: torch.fx.Graph):
"""
Process a multi-output match and manually insert the replacement.
This method should:
1. Insert the replacement nodes after the last node in the match.
2. Rebind the users of nodes in the match to use the new nodes.
3. Set meta["val"] for de-functionalization.
The result of an auto-functionalized node is a tuple of tensors.
The first element is the return value of the function, usually None.
The remaining elements are the mutated args of the function.
All auto-functionalized nodes must contain a proper meta["val"],
as it is used by de-functionalization. meta["val"] has to contain the
value of the node (tuple of tensors) that would be returned by the
functionalized node during tracing.
Existing nodes in the graph all have this property set, but we have
to set it manually for new nodes we insert.
Example:
# op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
# at.meta["val"] = (None, a, c)
"""
raise NotImplementedError

def inserting_after_match(self, graph: torch.fx.Graph):
"""
Insert nodes after the last node in the match.
This is done to avoid use-before-definition errors after inserting
replacement nodes.
"""

# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")

return graph.inserting_after(last_node_in_match)

def insert_getitems(self, graph: torch.fx.Graph, tuple_node: torch.fx.Node,
indices: Tuple[int, ...]):
"""
Insert operator.getitem nodes to extract elements from a tuple node.
:param graph: The graph to insert nodes into.
:param tuple_node: The tuple node to extract elements from.
:param indices: The indices of the elements to extract.
:return: List of the new getitem nodes, corresponding to the indices.
"""
with graph.inserting_after(tuple_node):
return [
graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices
]
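
For illustration, a minimal sketch of how a concrete subclass could use these helpers to implement the three steps described in process(). This is not part of the commit: the op handles OLD_OP and FUSED_OP are hypothetical placeholders, the fused op is assumed to accept the same kwargs as the matched op, and auto_functionalized is assumed importable from torch._higher_order_ops.auto_functionalize (this module already works with auto-functionalized nodes).

# Hypothetical sketch only -- OLD_OP / FUSED_OP are placeholder op handles.
from torch._higher_order_ops.auto_functionalize import auto_functionalized


class ExampleMatch(MultiOutputMatch):

    def process(self, graph: torch.fx.Graph):
        # The auto_functionalized node captured by the pattern match.
        old_node = find_auto_fn(self.nodes, OLD_OP)

        # 1. Insert the fused op after the last node of the match, then
        #    extract its mutated arg (index 1) with a getitem node.
        with self.inserting_after_match(graph):
            fused_node = graph.call_function(
                auto_functionalized, (FUSED_OP, ),
                kwargs=old_node.kwargs)  # assumes identical kwargs
            result_new, = self.insert_getitems(graph, fused_node, (1, ))

        # 2. Rebind users of the old output to the new getitem node;
        #    the orphaned old nodes are cleaned up by DCE later.
        find_getitem(old_node, 1).replace_all_uses_with(result_new)

        # 3. Set meta["val"] to the tuple the functionalized node would
        #    return during tracing: (None, mutated_arg).
        fused_node.meta["val"] = (None, old_node.meta["val"][1])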


class RMSNormQuantPattern:
@@ -170,15 +240,10 @@ def replacement(result: torch.Tensor, input: torch.Tensor,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(self.Match(m, self)))
extra_check=lambda m: record_match(self.Match(m)))

class Match(MultiOutputMatch):

def __init__(self, match: pm.Match,
pattern: 'FusedAddRMSNormQuantPattern'):
self.match = match
self.pattern = pattern

def process(self, graph: torch.fx.Graph):
# Find the nodes in the match that we need to rebind
rms_node = find_auto_fn(self.match.nodes,
@@ -191,8 +256,8 @@ def process(self, graph: torch.fx.Graph):

# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_functionalized node returns a tuple of
# (None, result, residual) - None is the function return value.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
@@ -212,51 +277,17 @@ def process(self, graph: torch.fx.Graph):
getitem_nodes = self.insert_getitems(graph, fused_node, (1, 2))
result_node_new, residual_node_new = getitem_nodes

# Next, rebind the users of nodes in the match to use the new nodes.
# Find the getitem nodes and replace their uses with the new nodes.
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)

# Finally, fix meta["val"] for de-functionalization
# Finally, fix meta["val"] for de-functionalization.
# See MultiOutputMatch.process for more details.
rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"]
# Result of fused node is (None, result, residual)
fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2])

def inserting_after_match(self, graph: torch.fx.Graph):
"""
TODO comment
:param graph:
:return:
"""
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")

return graph.inserting_after(last_node_in_match)

def insert_getitems(self, graph: torch.fx.Graph,
tuple_node: torch.fx.Node, indices: Tuple[int,
...]):
"""
TODO comment
:param graph:
:param tuple_node:
:param indices:
:return:
"""
with graph.inserting_after(tuple_node):
return [
graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices
]


class FusionPass(InductorPass):
"""
