From 0b91837469e834190e728248f6c23cc93f9ad6a6 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 29 Oct 2024 15:55:52 +0000 Subject: [PATCH 01/21] Refactor fusion patterns into class Signed-off-by: luka --- vllm/compilation/fusion.py | 375 +++++++++++++++++++++---------------- 1 file changed, 214 insertions(+), 161 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 5efa410fab6a0..a940911361f47 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,10 +1,10 @@ import operator -from typing import Iterable, List, Optional +from typing import Callable, Iterable, List, Optional, Tuple import torch +import torch._inductor.pattern_matcher as pm from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, - fwd_only, register_replacement) +from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import CompilationConfig from vllm.logger import init_logger @@ -14,69 +14,6 @@ logger = init_logger(__name__) -def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.rms_norm.default, - result=result_rms, - input=input, - weight=weight, - epsilon=1e-5) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - - # result - return at2[1] - - -def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=1e-5) - - # result - return at[1] - - -def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, - input=input, - residual=residual, - weight=weight, - epsilon=1e-5) - at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at[1], - scale=scale) - - # result, residual - return at1[1], at[2] - - -def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=1e-5) - # result, residual - return at[1], at[2] - - def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") @@ -126,6 +63,202 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: return ret +class MultiOutputMatch: + pass + + +class RMSNormQuantPattern: + + def __init__(self, epsilon: float): + self.epsilon = epsilon + + def register(self, pm_pass: PatternMatcherPass): + # Cannot use methods, as the self argument affects tracing + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.rms_norm.default, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized( + torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + + # result + return at2[1] + + def 
replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized( + torch.ops._C.rms_norm_static_fp8_quant.default, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result + return at[1] + + inputs = [ + empty_fp8(5, 4), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) + + +class FusedAddRMSNormQuantPattern: + + def __init__(self, epsilon: float): + self.epsilon = epsilon + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized( + torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at[1], + scale=scale) + + # result, residual + return at1[1], at[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual + return at[1], at[2] + + inputs = [ + empty_fp8(5, 4), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match(self.Match(m, self))) + + class Match(MultiOutputMatch): + + def __init__(self, match: pm.Match, + pattern: 'FusedAddRMSNormQuantPattern'): + self.match = match + self.pattern = pattern + + def process(self, graph: torch.fx.Graph): + # Find the nodes in the match that we need to rebind + rms_node = find_auto_fn(self.match.nodes, + torch.ops._C.fused_add_rms_norm.default) + quant_node = find_auto_fn( + self.match.nodes, torch.ops._C.static_scaled_fp8_quant.default) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and residual. + # The auto_functionalized node returns a tuple of + # (None, result, residual) - None is the function return value. + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + with self.inserting_after_match(graph): + kwargs = self.match.kwargs.copy() + + # Scalars cannot be inputs to the pattern + kwargs["epsilon"] = rms_node.kwargs["epsilon"] + + fused_node = graph.call_function( + auto_functionalized, + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + ), + kwargs=kwargs) + + getitem_nodes = self.insert_getitems(graph, fused_node, (1, 2)) + result_node_new, residual_node_new = getitem_nodes + + # Next, rebind the users of nodes in the match to use the new nodes. + # Find the getitem nodes and replace their uses with the new nodes. + # The old nodes will be removed by DCE at the end of the pass. 
+ find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + + # Finally, fix meta["val"] for de-functionalization + rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] + # Result of fused node is (None, result, residual) + fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2]) + + def inserting_after_match(self, graph: torch.fx.Graph): + """ + TODO comment + + :param graph: + :return: + """ + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return graph.inserting_after(last_node_in_match) + + def insert_getitems(self, graph: torch.fx.Graph, + tuple_node: torch.fx.Node, indices: Tuple[int, + ...]): + """ + TODO comment + + :param graph: + :param tuple_node: + :param indices: + :return: + """ + with graph.inserting_after(tuple_node): + return [ + graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices + ] + + class FusionPass(VllmInductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. @@ -158,41 +291,23 @@ def __init__(self, config: CompilationConfig.PassConfig): "FusionPass singleton instance already exists" super().__init__(config) - self.matches: List[Match] = [] + self.matches: List[MultiOutputMatch] = [] self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="fusion_pass") - # Fuse rms_norm + static_scaled_fp8_quant into - # rms_norm_static_fp8_quant - inputs = [ - empty_fp8(5, 4), - empty_bf16(5, 4), - empty_bf16(5, 4), - empty_bf16(1, 5), - empty_fp32(1, 1) - ] - register_replacement(rms_pattern_static, rms_replacement_static, - inputs, fwd_only, self.patterns) + for epsilon in [1e-5]: # TODO figure out how to do multiple epsilons + # Fuse rms_norm + static_scaled_fp8_quant into + # rms_norm_static_fp8_quant + RMSNormQuantPattern(epsilon).register(self.patterns) - # Fuse fused_add_rms_norm + static_scaled_fp8_quant into - # fused_add_rms_norm_static_fp8_quant - # Because pattern has 2 outputs, we need to manually process the match - # (see process_matches) - inputs = [ - empty_fp8(5, 4), - empty_bf16(5, 4), - empty_bf16(5, 4), - empty_bf16(1, 5), - empty_fp32(1, 1) - ] - register_replacement(rms_pattern_residual_static, - rms_replacement_residual_static, - inputs, - fwd_only, - self.patterns, - extra_check=lambda m: self.record_match(m)) - - def record_match(self, match: Match) -> bool: + # Fuse fused_add_rms_norm + static_scaled_fp8_quant into + # fused_add_rms_norm_static_fp8_quant + # Because pattern has 2 outputs, we need to manually process + # the match (see process_matches) + FusedAddRMSNormQuantPattern(epsilon).register( + self.patterns, self.record_match) + + def record_match(self, match: MultiOutputMatch) -> bool: # Hijack the extra_check to record the match and # save it for post-processing. self.matches.append(match) @@ -207,74 +322,12 @@ def process_matches(self, graph: torch.fx.Graph): matches is broken: https://github.com/pytorch/pytorch/issues/137280 """ for match in self.matches: - # To avoid use-before-definition errors, insert replacement nodes - # after the last node in the match. - # match.nodes is not guaranteed to be sorted. - # Find the last node in the match. 
- for last_node_in_match in reversed(graph.nodes): - if last_node_in_match in match.nodes: - break - else: - raise ValueError("No nodes in graph") - - # Insert a new auto_functionalized node for the fused operation, - # as well as getitem nodes to extract the result and residual. - # The auto_functionalized node returns a tuple of - # (None, result, residual) - None is the function return value. - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa - # result_node_new = at[1] - # residual_node_new = at[2] - with graph.inserting_after(last_node_in_match): - kwargs = match.kwargs - kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm - - fused_node = graph.call_function( - auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ), - kwargs=kwargs) - - graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, - (fused_node, 1)) - residual_node_new = graph.call_function( - operator.getitem, (fused_node, 2)) - - # Last part of replacement is rebinding the users of nodes in the - # match to use the new nodes. - - # Find the nodes in the match that we need to rebind - rms_node = find_auto_fn(match.nodes, - torch.ops._C.fused_add_rms_norm.default) - quant_node = find_auto_fn( - match.nodes, torch.ops._C.static_scaled_fp8_quant.default) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 1 - - # meta["val"] is used by de-functionalization and has to contain the - # value of the node (tuple of tensors) that would be returned by the - # functionalized node during tracing. - - rms_tup = rms_node.meta["val"] - quant_tup = quant_node.meta["val"] - - # The result of fused_node must be a tuple with the first element - # None (the function return value) and the remaining elements - # representing the mutated inputs. - fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2]) - fused_node.meta["val"] = fused_tup - - # Find the getitem nodes and replace their uses with the new nodes. - # The old nodes will be removed by DCE at the end of the pass. 
- find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + match.process(graph) # Finally, remove matched nodes graph.eliminate_dead_code() assert all(node not in graph.nodes for match in self.matches - for node in match.nodes) + for node in match.match.nodes) def __call__(self, graph: torch.fx.Graph): self.begin() From 948680a56207d7a090f6446023732174b0c6ca13 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 29 Oct 2024 16:19:41 +0000 Subject: [PATCH 02/21] Improved comments, move utils to MultiOutputMatch Signed-off-by: luka --- vllm/compilation/fusion.py | 125 +++++++++++++++++++++++-------------- 1 file changed, 78 insertions(+), 47 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index a940911361f47..039654d63769c 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,4 +1,6 @@ +import abc import operator +from abc import abstractmethod from typing import Callable, Iterable, List, Optional, Tuple import torch @@ -63,8 +65,76 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: return ret -class MultiOutputMatch: - pass +class MultiOutputMatch(abc.ABC): + + def __init__(self, match: pm.Match): + self.match = match + + @property + def nodes(self) -> List[torch.fx.Node]: + return self.match.nodes + + @abstractmethod + def process(self, graph: torch.fx.Graph): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. + + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + def inserting_after_match(self, graph: torch.fx.Graph): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return graph.inserting_after(last_node_in_match) + + def insert_getitems(self, graph: torch.fx.Graph, tuple_node: torch.fx.Node, + indices: Tuple[int, ...]): + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param graph: The graph to insert nodes into. + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. 
+ """ + with graph.inserting_after(tuple_node): + return [ + graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices + ] class RMSNormQuantPattern: @@ -171,15 +241,10 @@ def replacement(result: torch.Tensor, input: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match(self.Match(m, self))) + extra_check=lambda m: record_match(self.Match(m))) class Match(MultiOutputMatch): - def __init__(self, match: pm.Match, - pattern: 'FusedAddRMSNormQuantPattern'): - self.match = match - self.pattern = pattern - def process(self, graph: torch.fx.Graph): # Find the nodes in the match that we need to rebind rms_node = find_auto_fn(self.match.nodes, @@ -192,8 +257,8 @@ def process(self, graph: torch.fx.Graph): # First, insert a new auto_functionalized node for the fused op, # as well as getitem nodes to extract the result and residual. - # The auto_functionalized node returns a tuple of - # (None, result, residual) - None is the function return value. + # The auto_fn node returns a tuple of (None, result, residual). + # # The resulting graph looks like this: # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa # result_node_new = at[1] @@ -213,51 +278,17 @@ def process(self, graph: torch.fx.Graph): getitem_nodes = self.insert_getitems(graph, fused_node, (1, 2)) result_node_new, residual_node_new = getitem_nodes - # Next, rebind the users of nodes in the match to use the new nodes. - # Find the getitem nodes and replace their uses with the new nodes. + # Rebind the users of match getitem nodes to use the new nodes. # The old nodes will be removed by DCE at the end of the pass. find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - # Finally, fix meta["val"] for de-functionalization + # Finally, fix meta["val"] for de-functionalization. + # See MultiOutputMatch.process for more details. rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] # Result of fused node is (None, result, residual) fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2]) - def inserting_after_match(self, graph: torch.fx.Graph): - """ - TODO comment - - :param graph: - :return: - """ - # match.nodes is not guaranteed to be sorted. - # Find the last node in the match. 
- for last_node_in_match in reversed(graph.nodes): - if last_node_in_match in self.match.nodes: - break - else: - raise ValueError("No nodes in graph") - - return graph.inserting_after(last_node_in_match) - - def insert_getitems(self, graph: torch.fx.Graph, - tuple_node: torch.fx.Node, indices: Tuple[int, - ...]): - """ - TODO comment - - :param graph: - :param tuple_node: - :param indices: - :return: - """ - with graph.inserting_after(tuple_node): - return [ - graph.call_function(operator.getitem, (tuple_node, idx)) - for idx in indices - ] - class FusionPass(VllmInductorPass): """ From 32d26e91664a6304f7869e96cc650e032a647570 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 29 Oct 2024 16:22:26 +0000 Subject: [PATCH 03/21] Allow multiple epsilons by clearing pattern matcher cache Signed-off-by: luka --- tests/compile/test_fusion.py | 3 --- vllm/compilation/fusion.py | 6 +++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index f92ec8d0de5f1..034aa4e5f85c1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -48,9 +48,6 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) - if eps != 1e-5: - pytest.skip("Only test eps=1e-5 for now") - # Reshape pass is needed for the fusion pass to work config = CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=True) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 039654d63769c..5860227d01960 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -326,7 +326,7 @@ def __init__(self, config: CompilationConfig.PassConfig): self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="fusion_pass") - for epsilon in [1e-5]: # TODO figure out how to do multiple epsilons + for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static_scaled_fp8_quant into # rms_norm_static_fp8_quant RMSNormQuantPattern(epsilon).register(self.patterns) @@ -338,6 +338,10 @@ def __init__(self, config: CompilationConfig.PassConfig): FusedAddRMSNormQuantPattern(epsilon).register( self.patterns, self.record_match) + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + def record_match(self, match: MultiOutputMatch) -> bool: # Hijack the extra_check to record the match and # save it for post-processing. From 8c085384ca106e466bcbd8a1d4b5d73c134ed954 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 30 Oct 2024 19:03:22 +0000 Subject: [PATCH 04/21] Add graph as property of match, add comments, add utilities, extract ops to constants Signed-off-by: luka --- vllm/compilation/fusion.py | 119 ++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 49 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 5860227d01960..6aac17eaf8acc 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -66,16 +66,19 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. 
+ + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ def __init__(self, match: pm.Match): self.match = match - @property - def nodes(self) -> List[torch.fx.Node]: - return self.match.nodes - @abstractmethod - def process(self, graph: torch.fx.Graph): + def process(self): """ Process a multi-output match and manually insert the replacement. @@ -103,7 +106,21 @@ def process(self, graph: torch.fx.Graph): """ raise NotImplementedError - def inserting_after_match(self, graph: torch.fx.Graph): + @property + def nodes(self) -> List[torch.fx.Node]: + return self.match.nodes + + @property + def graph(self) -> torch.fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> torch.fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): """ Insert nodes after the last node in the match. This is done to avoid use-before-definition errors after inserting @@ -112,29 +129,40 @@ def inserting_after_match(self, graph: torch.fx.Graph): # match.nodes is not guaranteed to be sorted. # Find the last node in the match. - for last_node_in_match in reversed(graph.nodes): + for last_node_in_match in reversed(self.graph.nodes): if last_node_in_match in self.match.nodes: break else: raise ValueError("No nodes in graph") - return graph.inserting_after(last_node_in_match) + return self.graph.inserting_after(last_node_in_match) - def insert_getitems(self, graph: torch.fx.Graph, tuple_node: torch.fx.Node, - indices: Tuple[int, ...]): + def insert_getitems(self, tuple_node: torch.fx.Node, + indices: Tuple[int, ...]) -> Tuple[torch.fx.Node, ...]: """ Insert operator.getitem nodes to extract elements from a tuple node. - :param graph: The graph to insert nodes into. :param tuple_node: The tuple node to extract elements from. :param indices: The indices of the elements to extract. :return: Tuple of the new getitem nodes, corresponding to the indices. """ - with graph.inserting_after(tuple_node): - return [ - graph.call_function(operator.getitem, (tuple_node, idx)) - for idx in indices - ] + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op, kwargs): + """ + Insert an auto_functionalized node with the given op and kwargs. 
+ """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) + + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + +QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default class RMSNormQuantPattern: @@ -142,21 +170,23 @@ class RMSNormQuantPattern: def __init__(self, epsilon: float): self.epsilon = epsilon + +class RMSNormStaticFP8QuantPattern(RMSNormQuantPattern): + def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing def pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.rms_norm.default, + at1 = auto_functionalized(RMS_OP, result=result_rms, input=input, weight=weight, epsilon=self.epsilon) - at2 = auto_functionalized( - torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) + at2 = auto_functionalized(QUANT_STATIC_FP8_OP, + result=result, + input=at1[1], + scale=scale) # result return at2[1] @@ -187,10 +217,7 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, pm_pass) -class FusedAddRMSNormQuantPattern: - - def __init__(self, epsilon: float): - self.epsilon = epsilon +class FusedAddRMSNormStaticFP8QuantPattern(RMSNormQuantPattern): def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): @@ -198,16 +225,15 @@ def register(self, pm_pass: PatternMatcherPass, def pattern(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, + at = auto_functionalized(RMS_ADD_OP, input=input, residual=residual, weight=weight, epsilon=self.epsilon) - at1 = auto_functionalized( - torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at[1], - scale=scale) + at1 = auto_functionalized(QUANT_STATIC_FP8_OP, + result=result, + input=at[1], + scale=scale) # result, residual return at1[1], at[2] @@ -245,12 +271,10 @@ def replacement(result: torch.Tensor, input: torch.Tensor, class Match(MultiOutputMatch): - def process(self, graph: torch.fx.Graph): + def process(self): # Find the nodes in the match that we need to rebind - rms_node = find_auto_fn(self.match.nodes, - torch.ops._C.fused_add_rms_norm.default) - quant_node = find_auto_fn( - self.match.nodes, torch.ops._C.static_scaled_fp8_quant.default) + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(QUANT_STATIC_FP8_OP) assert len(rms_node.users) == 2 assert len(quant_node.users) == 1 @@ -263,19 +287,17 @@ def process(self, graph: torch.fx.Graph): # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) 
# noqa # result_node_new = at[1] # residual_node_new = at[2] - with self.inserting_after_match(graph): + with self.inserting_after_match(): kwargs = self.match.kwargs.copy() # Scalars cannot be inputs to the pattern kwargs["epsilon"] = rms_node.kwargs["epsilon"] - fused_node = graph.call_function( - auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ), - kwargs=kwargs) + fused_node = self.insert_auto_fn( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + kwargs) - getitem_nodes = self.insert_getitems(graph, fused_node, (1, 2)) + getitem_nodes = self.insert_getitems(fused_node, (1, 2)) result_node_new, residual_node_new = getitem_nodes # Rebind the users of match getitem nodes to use the new nodes. @@ -329,13 +351,13 @@ def __init__(self, config: CompilationConfig.PassConfig): for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static_scaled_fp8_quant into # rms_norm_static_fp8_quant - RMSNormQuantPattern(epsilon).register(self.patterns) + RMSNormStaticFP8QuantPattern(epsilon).register(self.patterns) # Fuse fused_add_rms_norm + static_scaled_fp8_quant into # fused_add_rms_norm_static_fp8_quant # Because pattern has 2 outputs, we need to manually process # the match (see process_matches) - FusedAddRMSNormQuantPattern(epsilon).register( + FusedAddRMSNormStaticFP8QuantPattern(epsilon).register( self.patterns, self.record_match) # WARNING: This is a hack to clear the pattern matcher cache @@ -353,11 +375,10 @@ def record_match(self, match: MultiOutputMatch) -> bool: def process_matches(self, graph: torch.fx.Graph): """ Manually process multi-output matches and replace them with fused nodes. - This is necessary because the automatic replacement for multi-output - matches is broken: https://github.com/pytorch/pytorch/issues/137280 + See MultiOutputMatch for more details. 
""" for match in self.matches: - match.process(graph) + match.process() # Finally, remove matched nodes graph.eliminate_dead_code() From 7ea544e55c52db863fc1641f76bb6c448d523a79 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 15:26:37 +0000 Subject: [PATCH 05/21] dynamic quant (fused ops in python) Signed-off-by: luka --- tests/compile/test_fusion.py | 29 ++- vllm/compilation/fix_functionalization.py | 8 +- vllm/compilation/fusion.py | 238 ++++++++++++++++++++++ 3 files changed, 265 insertions(+), 10 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 034aa4e5f85c1..b75249666d771 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -16,10 +16,15 @@ class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, *args, **kwargs): + def __init__(self, hidden_size: int, eps: float, static: bool, *args, + **kwargs): super().__init__(*args, **kwargs) self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + if static: + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + else: + self.scale = [None for _ in range(2)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(2) @@ -29,11 +34,11 @@ def forward(self, x): resid = torch.relu(x) y = self.norm[0](x) - x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1]) + x2 = apply_fp8_linear(y, self.w[0], self.wscale[0], self.scale[0]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3]) + x3 = apply_fp8_linear(y2, self.w[1], self.wscale[1], self.scale[1]) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -42,9 +47,10 @@ def forward(self, x): @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("static", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -55,7 +61,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): fusion_pass = FusionPass.instance(config) backend = TestBackend(reshape_pass, fusion_pass) - model = TestModel(hidden_size, eps) + model = TestModel(hidden_size, eps, static) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) @@ -73,9 +79,14 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): pre_nodes = backend.graph_pre_pass.nodes post_nodes = backend.graph_post_pass.nodes - rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default - add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + if static: + rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501 + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + else: + rms_quant = torch.ops._C.rms_norm_dynamic_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default # noqa: E501 + fp8_quant = 
torch.ops._C.dynamic_scaled_fp8_quant.default # In pre-nodes, fp8 quant should be present and fused kernels should not assert find_auto_fn_maybe(pre_nodes, rms_quant) is None diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 3584cc3608caf..1c87c7771b21d 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -53,13 +53,16 @@ def __call__(self, graph: torch.fx.Graph): self.insert_defunctionalized(graph, node) self._remove(node) - # These 2 replacements avoid the most copies for LLaMa. + # rms_norm replacements avoid the most copies for LLaMa. elif at_target == torch.ops._C.fused_add_rms_norm.default: mutated_args = {1: 'input', 2: 'residual'} self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 mutated_args = {1: 'result', 2: 'residual'} self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'residual', 3: 'scale'} + self.defunctionalize(graph, node, mutated_args) elif at_target in [ torch.ops._C.rms_norm.default, @@ -67,6 +70,9 @@ def __call__(self, graph: torch.fx.Graph): ]: mutated_args = {1: 'result'} self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.rms_norm_dynamic_fp8_quant.default: + mutated_args = {1: 'result', 2: 'scale'} + self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.silu_and_mul.default: mutated_args = {1: 'out'} diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 6aac17eaf8acc..e4c8c974623f7 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -16,6 +16,41 @@ logger = init_logger(__name__) +# TODO temp +@torch.library.custom_op("_C::rms_norm_dynamic_fp8_quant", + mutates_args=("result", "scale")) +def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, + epsilon: float) -> None: + result_rms = torch.empty_like(input) + torch.ops._C.rms_norm(result_rms, input, weight, epsilon) + torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale) + + +@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, epsilon: float): + return None + + +@torch.library.custom_op("_C::fused_add_rms_norm_dynamic_fp8_quant", + mutates_args=("result", "residual", "scale")) +def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) + torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale) + + +@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, epsilon: float): + return None + + def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") @@ -163,6 +198,7 @@ def insert_auto_fn(self, op, kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default +QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_scaled_fp8_quant.default class RMSNormQuantPattern: @@ -312,6 +348,198 @@ def process(self): fused_node.meta["val"] = (None, 
quant_tup[1], rms_tup[2]) +class RMSNormDynamicFP8QuantPattern(RMSNormQuantPattern): + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, + result=result, + input=at1[1], + scale=scale) + + # result, scale + return at2[1], at2[2] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized( + torch.ops._C.rms_norm_static_fp8_quant.default, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, scale + return at[1], at[2] + + inputs = [ + empty_fp8(5, 4), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match(self.Match(m))) + + class Match(MultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_OP) + quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP) + + assert len(rms_node.users) == 1 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and scale. + # The auto_fn node returns a tuple of (None, result, scale). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # scale_node_new = at[2] + with self.inserting_after_match(): + kwargs = self.match.kwargs.copy() + + # Scalars cannot be inputs to the pattern + kwargs["epsilon"] = rms_node.kwargs["epsilon"] + del kwargs["result_rms"] # not used in the fused op + + fused_node = self.insert_auto_fn( + torch.ops._C.rms_norm_dynamic_fp8_quant.default, + kwargs=kwargs) + + getitem_nodes = self.insert_getitems(fused_node, (1, 2)) + result_node_new, scale_node_new = getitem_nodes + + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) + + # Finally, fix meta["val"] for de-functionalization. + # See MultiOutputMatch.process for more details. 
+ # Result of fused node is (None, result, scale) + fused_node.meta["val"] = quant_node.meta["val"] + + +class FusedAddRMSNormDynamicFP8QuantPattern(RMSNormQuantPattern): + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, + result=result, + input=at[1], + scale=scale) + + # result, residual, scale + return at1[1], at[2], at1[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized( + torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual, scale + return at[1], at[2], at[3] # TODO confirm signature + + inputs = [ + empty_fp8(5, 4), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match(self.Match(m))) + + class Match(MultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract result, scale, and residual. + # The auto_fn node returns a tuple (None, result, scale, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + # scale_node_new = at[3] + with self.inserting_after_match(): + kwargs = self.match.kwargs.copy() + + # Scalars cannot be inputs to the pattern + kwargs["epsilon"] = rms_node.kwargs["epsilon"] + + fused_node = self.insert_auto_fn( + torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + kwargs=kwargs) + + getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) + result_node_new, residual_node_new, scale_node_new = getitem_ns + + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) + + # Finally, fix meta["val"] for de-functionalization. + # See MultiOutputMatch.process for more details. + rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] + # Result of fused node is (None, result, scale, residual) + fused_node.meta["val"] = (None, quant_tup[1], quant_tup[2], + rms_tup[2]) + + class FusionPass(VllmInductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. 
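A minimal sanity check for the Python fused op introduced above (a sketch, not part of the diff). It assumes a CUDA build of vLLM's _C extension and the fused_add_rms_norm_dynamic_fp8_quant custom op registered earlier in this commit; the helper name and shapes are illustrative. Because the op is currently a composition of fused_add_rms_norm and dynamic_scaled_fp8_quant, the check passes trivially, but it becomes meaningful once later commits in this series swap in a dedicated fused kernel.

import torch

def check_fused_add_rms_norm_dynamic_fp8_quant(num_tokens=8,
                                               hidden_size=64,
                                               eps=1e-6):
    # Illustrative helper, not part of the patch.
    x = torch.rand(num_tokens, hidden_size,
                   dtype=torch.bfloat16, device="cuda")
    residual = torch.rand_like(x)
    weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda")

    # Unfused reference: fused_add_rms_norm mutates (input, residual),
    # dynamic_scaled_fp8_quant writes (result, scale).
    x_ref, res_ref = x.clone(), residual.clone()
    ref_out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    ref_scale = torch.zeros(1, dtype=torch.float32, device="cuda")
    torch.ops._C.fused_add_rms_norm(x_ref, res_ref, weight, eps)
    torch.ops._C.dynamic_scaled_fp8_quant(ref_out, x_ref, ref_scale)

    # Fused op from this commit (currently a Python composition).
    x_fused, res_fused = x.clone(), residual.clone()
    out = torch.empty_like(ref_out)
    scale = torch.zeros_like(ref_scale)
    torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant(
        out, x_fused, res_fused, weight, scale, eps)

    torch.testing.assert_close(out.float(), ref_out.float())
    torch.testing.assert_close(scale, ref_scale)
    torch.testing.assert_close(res_fused, res_ref)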
@@ -360,6 +588,16 @@ def __init__(self, config: CompilationConfig.PassConfig): FusedAddRMSNormStaticFP8QuantPattern(epsilon).register( self.patterns, self.record_match) + # Fuse rms_norm + dynamic_scaled_fp8_quant into + # rms_norm_dynamic_fp8_quant + RMSNormDynamicFP8QuantPattern(epsilon).register( + self.patterns, self.record_match) + + # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into + # fused_add_rms_norm_dynamic_fp8_quant + FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register( + self.patterns, self.record_match) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear() From 0678245cfd81f3cb5f12aad26361b1cc15b893a1 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 5 Aug 2024 15:25:12 +0000 Subject: [PATCH 06/21] add fused-rms-quant-dyn-per-token branch Signed-off-by: luka --- CMakeLists.txt | 3 +- .../fused_kernels/layernorm_rms_benchmarks.py | 173 ++++++++++ csrc/dispatch_utils.h | 7 + csrc/ops.h | 8 + ...fused_layernorm_dynamic_per_token_quant.cu | 160 ++++++++++ .../fused_kernels/layernorm_utils.cuh | 299 ++++++++++++++++++ .../fused_kernels/quant_conversions.cuh | 79 +++++ .../fused_kernels/vectorization.cuh | 27 ++ csrc/torch_bindings.cpp | 8 + tests/kernels/test_fused_quant_layernorm.py | 156 +++++++++ vllm/_custom_ops.py | 21 ++ 11 files changed, 940 insertions(+), 1 deletion(-) create mode 100644 benchmarks/fused_kernels/layernorm_rms_benchmarks.py create mode 100644 csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu create mode 100644 csrc/quantization/fused_kernels/layernorm_utils.cuh create mode 100644 csrc/quantization/fused_kernels/quant_conversions.cuh create mode 100644 csrc/quantization/fused_kernels/vectorization.cuh create mode 100644 tests/kernels/test_fused_quant_layernorm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77a7e42..bf19b3d227171 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" @@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. 
- cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py new file mode 100644 index 0000000000000..4dbdd3638aad3 --- /dev/null +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -0,0 +1,173 @@ +import pickle as pkl +import time +from dataclasses import dataclass +from itertools import product +from typing import Callable, Iterable, List, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.model_executor.layers.layernorm import RMSNorm + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + add_residual: bool + dtype: torch.dtype + + def description(self): + return (f'N {self.num_tokens} ' + f'x D {self.hidden_size} ' + f'x R {self.add_residual} ' + f'x DT {self.dtype}') + + +def get_bench_params() -> List[bench_params_t]: + ## Test Fixtures + NUM_TOKENS = [2**x for x in range(11)] + HIDDEN_SIZES = list(range(1024, 8129, 1024)) + ADD_RESIDUAL = [True, False] + DTYPES = [torch.bfloat16, torch.float] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + bench_params = list(map(lambda x: \ + bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + return bench_params + + +# Reference impls +def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_int8_quant(torch_out) + + +def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_fp8_quant(torch_out) + + +def fused_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + out, _ = ops.rms_norm_dynamic_per_token_quant(x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + residual=residual) + + +# Bench functions +def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, + quant_dtype: torch.dtype, label: str, sub_label: str, + fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "rms_norm_layer": rms_norm_layer, + "x": x, + "residual": residual, + "quant_dtype": quant_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + +def bench(params: bench_params_t, label: str, sub_label: str) \ + -> Iterable[TMeasurement]: + + # Make inputs + layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + # Make inputs + scale = 1 / params.hidden_size + x = torch.randn(params.num_tokens, + 
params.hidden_size, + dtype=params.dtype, + device='cuda') * scale + residual = (torch.randn_like(x) * scale).to(device='cuda') \ + if params.add_residual else None + + timers = [] + + # unfused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, + unfused_int8_impl, "unfused_int8_impl")) + + # unfused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + unfused_fp8_impl, "unfused_fp8_impl")) + + # fused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, + "fused_int8_impl")) + + # fused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + fused_impl, "fused_fp8_impl")) + + print_timers(timers) + + return timers + + +# launch bench +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def main(): + torch.set_default_device('cuda') + bench_params = get_bench_params() + + timers = [] + for bp in tqdm(bench_params): + timers.extend( + bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + print_timers(timers) + + # pickle all the results + timestamp = int(time.time()) + with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: + pkl.dump(timers, f) + + +if __name__ == '__main__': + main() diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index a634e1c3d4886..aa5c8dbbae182 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,6 +14,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) 
\ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index ea001190bc202..816b471d062d2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& weight, torch::Tensor& scale, double epsilon); +void rms_norm_dynamic_per_token_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, + double const epsilon, + std::optional scale_ub, + std::optional residual); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu new file mode 100644 index 0000000000000..38d2e3703672c --- /dev/null +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -0,0 +1,160 @@ + +#include +#include + +#include "../../dispatch_utils.h" +#include "../../reduction_utils.cuh" +#include "layernorm_utils.cuh" +#include "quant_conversions.cuh" + +namespace vllm { + +template +__device__ void rms_norm_dynamic_per_token_quant_vec( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute rms + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute scale + vllm::vectorized::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::vectorized::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert s_token_scale for exact match with FBGemm + vllm::vectorized::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} + +// RMS norm + quant kernel +template +__global__ void rms_norm_dynamic_per_token_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. 
+ bool const can_vectorize = hidden_size % 4 == 0; + + if (can_vectorize) { + return rms_norm_dynamic_per_token_quant_vec( + out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor, + hidden_size, residual); + } + + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute RMS + vllm::compute_rms(&rms, input, hidden_size, + var_epsilon, residual); + // Compute Scale + vllm::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert s_token_scale for exact match with FBGemm + vllm::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} +} // namespace vllm + +// Residual add + RMS norm + dynamic per token +template +void rms_norm_dynamic_per_token_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual) { + int32_t hidden_size = input.size(-1); + int32_t num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const float min_scaling_factor = + out.dtype() == torch::kInt8 + ? std::numeric_limits::epsilon() + : 1.0f / (std::numeric_limits::max() * 512.f); + + if (residual.has_value()) { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, + residual->data_ptr()); + }); + + } else { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, nullptr); + }); + } +} + +void rms_norm_dynamic_per_token_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional scale_ub, std::optional residual) { + TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn || + out.dtype() == torch::kInt8); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); + } + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { + rms_norm_dynamic_per_token_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual); + }); +} diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh new file mode 100644 index 0000000000000..395b3fa6ba7b7 --- /dev/null +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -0,0 +1,299 @@ +#pragma once + +/** + * __device__ layernorm utilities. + */ + +#include "vectorization.cuh" +#include "quant_conversions.cuh" + +namespace vllm { + +// has_residual must be true, if residual is not a nullptr +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * hidden_size; + // sum of squares + float ss = 0.0f; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + ss += x * x; + } + ss = blockReduceSum(ss); + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * hidden_size; + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + float block_absmax_val_maybe = 0.0f; + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + block_absmax_val_maybe = blockReduceMax(block_absmax_val_maybe); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + 
float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * hidden_size; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + residual[token_offset + i] = static_cast(x); + } + // Norm + x = static_cast(static_cast(x * rms) * weight[i]); + // Quant + output[token_offset + i] = + ScaledQuant::quant_fn(x, scale); + } +} + +namespace vectorized { + +// Compute 1.0/rms(input) +// hidden_size must be a multiple of 4 +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * hidden_size; + + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + // sum of squares + float ss = 0.0f; + + int32_t const num_vec_elems = hidden_size >> 2; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + ss += x.x * x.x; + ss += x.y * x.y; + ss += x.z * x.z; + ss += x.w * x.w; + } + + ss = blockReduceSum(ss); + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +// Vectorized version of vllm::compute_dynamic_per_token_scales +// hidden_size must be a multiple of 4 +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * hidden_size; + + // Vectorized input/weight/residual to better utilize memory bandwidth. 
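
For reference, the per-token scale that compute_dynamic_per_token_scales produces (in both the scalar version above and the vectorized version that follows) corresponds roughly to this PyTorch sketch; the helper name and shapes are illustrative, not part of the kernels:

from typing import Optional
import torch

def per_token_scale(y: torch.Tensor, qmax: float, min_scaling_factor: float,
                    scale_ub: Optional[torch.Tensor] = None) -> torch.Tensor:
    # y: normalized-and-weighted activations, shape [num_tokens, hidden_size]
    absmax = y.abs().amax(dim=-1, keepdim=True).float()
    if scale_ub is not None:
        absmax = torch.minimum(absmax, scale_ub)
    # One scale per token, floored so the later division never blows up
    return torch.clamp(absmax / qmax, min=min_scaling_factor)

Here qmax is 127 for int8 and 448 for float8_e4m3fn. Downstream, the int8 path multiplies by the precomputed reciprocal 1.0f / token_scale, while the FP8 path divides by the scale directly to stay exactly in line with FBGemm.
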
+ vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + int32_t const num_vec_elems = hidden_size >> 2; + float block_absmax_val_maybe = 0.0f; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.x * rms) * w.x)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.y * rms) * w.y)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.z * rms) * w.z)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.w * rms) * w.w)); + } + + block_absmax_val_maybe = blockReduceMax(block_absmax_val_maybe); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +// hidden_size must be a multiple of 4 +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * hidden_size; + + // Vectorized input/output/weight/residual to better utilize memory bandwidth. 
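
Taken together (residual add, RMS norm, weight multiply, dynamic per-token quant), the behavior the scalar and vectorized kernels implement is approximately the following PyTorch reference. This is a simplified sketch: it ignores the intermediate bf16 round-trips and the min-scale floor shown above, so it is numerically close rather than bit-exact, and the function name is made up:

from typing import Optional, Tuple
import torch

def ref_rms_norm_dynamic_quant(
        x: torch.Tensor, weight: torch.Tensor, eps: float,
        quant_dtype: torch.dtype,
        residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    if residual is not None:
        x = x + residual
        residual = x                      # the kernel writes this back in place
    xf = x.float()
    rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = xf * rms * weight.float()
    qmax = 127.0 if quant_dtype == torch.int8 else torch.finfo(quant_dtype).max
    scale = y.abs().amax(dim=-1, keepdim=True) / qmax
    q = (y / scale).clamp(-qmax, qmax)
    if quant_dtype == torch.int8:
        q = q.round()
    return q.to(quant_dtype), scale, residual
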
+ vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + q8x4_t* vec_output = + reinterpret_cast*>(&output[token_offset]); + vec4_t* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = reinterpret_cast*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = hidden_size >> 2; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t const in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + // Update residual + r.x = static_cast(x.x); + r.y = static_cast(x.y); + r.z = static_cast(x.z); + r.w = static_cast(x.w); + vec_residual[i] = r; + } + + q8x4_t out; + out.x = ScaledQuant::quant_fn( + static_cast(x.x * rms) * w.x, scale); + out.y = ScaledQuant::quant_fn( + static_cast(x.y * rms) * w.y, scale); + out.z = ScaledQuant::quant_fn( + static_cast(x.z * rms) * w.z, scale); + out.w = ScaledQuant::quant_fn( + static_cast(x.w * rms) * w.w, scale); + vec_output[i] = out; + } +} + +} // namespace vectorized + +} // namespace vllm diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh new file mode 100644 index 0000000000000..abf32f40c4b0e --- /dev/null +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -0,0 +1,79 @@ +#pragma once + +/** + * __device__ helper functions to deal with float -> quant datatype conversion + */ + +#include "vectorization.cuh" + +namespace vllm { + +static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +#define FP8_E4M3_MAX std::numeric_limits::max() +static __device__ __forceinline__ c10::Float8_e4m3fn float_to_fp8( + float const x) { + float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + +template +struct ScaledQuant; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_int8_rn(x * scale); + } else { + return float_to_int8_rn(x / scale); + } + } +}; + +template +struct ScaledQuant>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_fp8(x * scale); + } else { + return float_to_fp8(x / scale); + } + } +}; + +template +__device__ void scaled_quant_conversion(quant_type_t* __restrict__ output, + scalar_t const* __restrict__ input, + float const scale, int const tid, + int const num_elements, + int const step) { + for (int i = tid; i < num_elements; i += step) { + output[i] = ScaledQuant(input[i], scale); + } +} + +} // namespace vllm diff --git 
a/csrc/quantization/fused_kernels/vectorization.cuh b/csrc/quantization/fused_kernels/vectorization.cuh new file mode 100644 index 0000000000000..7ba0df6b11ce4 --- /dev/null +++ b/csrc/quantization/fused_kernels/vectorization.cuh @@ -0,0 +1,27 @@ +#pragma once +/** + * __device__ algorithms that perform vectorized loads/stores of input/output. + */ + +namespace vllm { + +// Vectorization containers +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +template +struct __align__(4) q8x4_t { + static_assert(std::is_same_v || + std::is_same_v); + quant_type_t x; + quant_type_t y; + quant_type_t z; + quant_type_t w; +}; + +} // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4e64b9c92773a..818ad231e1543 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -128,6 +128,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + // Fused Layernorm + Quant kernels + ops.def( + "rms_norm_dynamic_per_token_quant(Tensor! out, Tensor input, " + "Tensor weight, Tensor! scales, float epsilon, " + "Tensor? scale_ub, Tensor!? residual) -> ()"); + ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, + &rms_norm_dynamic_per_token_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py new file mode 100644 index 0000000000000..ec3ad5ab5dd8b --- /dev/null +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -0,0 +1,156 @@ +from typing import Optional, Tuple, Union + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.model_executor.layers.layernorm import RMSNorm + +DTYPES = [torch.bfloat16, torch.float] +QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, + 8193] # Arbitrary values for testing +HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +ADD_RESIDUAL = [False, True] +SCALE_UBS = [True, False] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +EPS = 1e-6 + +## Helpers + + +def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: + return torch.as_tensor(x, dtype=torch.float32, device='cuda') + +def ref_rms_norm(rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + if residual is not None: + residual = residual.clone() + out, residual = rms_norm_layer.forward_native(x, residual) + else: + out = rms_norm_layer.forward_native(x) + + return out, residual + +def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + + if scale_ub is not None: + assert quant_dtype == torch.float8_e4m3fn + + # Norm + torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual) + + # Quant + if quant_dtype == torch.float8_e4m3fn: + torch_out, scales = ops.scaled_fp8_quant(torch_out, + scale_ub=scale_ub, + use_per_token_if_dynamic=True) + else: + assert quant_dtype == torch.int8 + torch_out, scales = ops.scaled_int8_quant(torch_out) + + return torch_out, scales, 
residual + +def ref_impl(rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, + residual, scale_ub) + +def ops_dynamic_per_token_quant(weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + + if residual is not None: + residual = residual.clone() + out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, + quant_dtype, scale_ub, + residual) + return out, scales, residual + +def ops_impl(weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, + scale_ub) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) +@pytest.mark.parametrize("scale_ub", SCALE_UBS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_rms_norm( + num_tokens: int, + hidden_size: int, + add_residual: bool, + scale_ub: bool, + dtype: torch.dtype, + quant_dtype: torch.dtype, + seed: int, + device: str, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + if scale_ub is not None and quant_dtype != torch.float8_e4m3fn: + # skip + return + + layer = RMSNorm(hidden_size, EPS).to(dtype=dtype) + + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + + # Make inputs + scale = 1 / (hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + residual = torch.randn_like(x) * scale if add_residual else None + if scale_ub is not None: + rms_x, _ = ref_rms_norm(layer, x, residual) + scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda') + + ref_out, ref_scales, ref_residual = \ + ref_impl(layer, x, quant_dtype, residual, scale_ub) + ops_out, ops_scales, ops_residual = \ + ops_impl(layer.weight, x, quant_dtype, residual, scale_ub) + + assert ref_out.dtype == quant_dtype + assert ops_out.dtype == quant_dtype + assert torch.allclose(ref_scales, ops_scales) + if quant_dtype == torch.int8: + # big atol to account for round-off errors. 
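
The atol=1 used below for int8 is needed because the fused kernel quantizes in a single pass while the reference normalizes and quantizes in two passes, so a value sitting near the midpoint between two int8 codes can round in different directions. A tiny illustration of how a ~1e-5 difference before rounding becomes a full code of difference after it:

import torch

x = torch.tensor([3.49999, 3.50001])
print(torch.round(x).to(torch.int8))  # tensor([3, 4], dtype=torch.int8)
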
+ assert torch.allclose(ref_out, ops_out, atol=1) + else: + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + if add_residual: + assert torch.allclose(ref_residual, ops_residual) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c192c9a7b0e4d..f72c82673cad5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -249,6 +249,27 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, block_table_bound) +# fused quant layer norm ops +def rms_norm_dynamic_per_token_quant( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + + output = torch.empty_like(input, dtype=quant_dtype) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + + torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight, + scales, epsilon, scale_ub, + residual) + return output, scales + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, From 62bf187b2d2eb7b06a4c58304f8476921b6c64b7 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 21 Nov 2024 22:01:44 +0000 Subject: [PATCH 07/21] Upgrade reduction Signed-off-by: luka --- ...fused_layernorm_dynamic_per_token_quant.cu | 1 - .../fused_kernels/layernorm_utils.cuh | 30 ++++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 38d2e3703672c..e4c9d66c33b40 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -3,7 +3,6 @@ #include #include "../../dispatch_utils.h" -#include "../../reduction_utils.cuh" #include "layernorm_utils.cuh" #include "quant_conversions.cuh" diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 395b3fa6ba7b7..05ed221930dc8 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -7,6 +7,12 @@ #include "vectorization.cuh" #include "quant_conversions.cuh" +#ifndef USE_ROCM + #include +#else + #include +#endif + namespace vllm { // has_residual must be true, if residual is not a nullptr @@ -26,7 +32,11 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, ss += x * x; } - ss = blockReduceSum(ss); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + __shared__ float s_rms; if (threadIdx.x == 0) { s_rms = rsqrtf(ss / hidden_size + epsilon); @@ -56,7 +66,12 @@ __device__ void compute_dynamic_per_token_scales( x = static_cast(static_cast(x * rms) * weight[i]); block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } - block_absmax_val_maybe = blockReduceMax(block_absmax_val_maybe); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { @@ -147,7 +162,10 @@ __device__ void compute_rms(float* rms, scalar_t const* 
__restrict__ input, ss += x.w * x.w; } - ss = blockReduceSum(ss); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + __shared__ float s_rms; if (threadIdx.x == 0) { s_rms = rsqrtf(ss / hidden_size + epsilon); @@ -212,7 +230,11 @@ __device__ void compute_dynamic_per_token_scales( block_absmax_val_maybe, fabs(static_cast(x.w * rms) * w.w)); } - block_absmax_val_maybe = blockReduceMax(block_absmax_val_maybe); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { From a3a2b69272f1ac1e80099bc54f80c120de8298e8 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 21 Nov 2024 22:02:18 +0000 Subject: [PATCH 08/21] Use new dynamic ops for fusion, tolerance has to be higher. Signed-off-by: luka --- .../fused_kernels/layernorm_rms_benchmarks.py | 2 +- csrc/torch_bindings.cpp | 2 +- tests/compile/test_fusion.py | 5 ++-- vllm/_custom_ops.py | 29 +++++++++++++++---- vllm/compilation/fusion.py | 9 +++--- 5 files changed, 32 insertions(+), 15 deletions(-) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 4dbdd3638aad3..ef91f9f8eb529 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -52,7 +52,7 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, torch_out, _ = rms_norm_layer.forward_cuda(x, residual) # Quant - torch_out, _ = ops.scaled_int8_quant(torch_out) + torch_out, _, _ = ops.scaled_int8_quant(torch_out) def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 818ad231e1543..d5de65a5c9e30 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Fused Layernorm + Quant kernels ops.def( - "rms_norm_dynamic_per_token_quant(Tensor! out, Tensor input, " + "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " "Tensor weight, Tensor! scales, float epsilon, " "Tensor? scale_ub, Tensor!? 
residual) -> ()"); ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index b75249666d771..d30ee82117866 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -72,8 +72,9 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): model2 = torch.compile(model, backend=backend) result2 = model2(x) - # Check that it gives the same answer - torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3) + # Check that it gives the same answer, higher tol for dynamic + ATOL, RTOL = (1e-3, 1e-3) if static else (1e-2, 1e-2) + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) # Check substitution worked pre_nodes = backend.graph_pre_pass.nodes diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f72c82673cad5..8503452df15df 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -22,6 +22,7 @@ supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 + supports_moe_ops = True # neuron has torch version that doesn't even have impl_abstract @@ -241,7 +242,6 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, paged_kv_indptr: torch.Tensor, paged_kv_last_page_len: torch.Tensor, block_table_bound: torch.Tensor) -> None: - return torch.ops._C.advance_step_flashinfer( num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, @@ -258,7 +258,6 @@ def rms_norm_dynamic_per_token_quant( scale_ub: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=quant_dtype) scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, @@ -270,6 +269,24 @@ def rms_norm_dynamic_per_token_quant( return output, scales +# TODO is this necessary? +@register_fake("_C::rms_norm_dynamic_per_token_quant") +def _rms_norm_dynamic_per_token_quant_fake( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + output = torch.empty_like(input, dtype=quant_dtype) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + + return output, scales + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, @@ -726,7 +743,7 @@ def scaled_fp8_quant( shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + if current_platform.is_rocm() else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=out_dtype) @@ -1009,9 +1026,9 @@ def register_graph_buffers(fa: int, handles: List[List[int]], # the case when users use `import __annotations__` to turn type # hints into strings. 
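
For orientation, the new Python wrapper added to this file is intended to be called like this (a usage sketch with made-up shapes, run on a CUDA machine with vLLM's extensions built; argument names follow the signature above):

import torch
import vllm._custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, dtype=torch.bfloat16, device="cuda")

out, scales = ops.rms_norm_dynamic_per_token_quant(
    x, w, epsilon=1e-6, quant_dtype=torch.float8_e4m3fn)
# out: [16, 4096] float8_e4m3fn, scales: [16, 1] float32 (one scale per token)
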
if isinstance(v, fn_type) \ - and v.__code__.co_filename == __file__ \ - and any(arg is torch.Tensor or arg == "torch.Tensor" - for arg in v.__annotations__.values()): + and v.__code__.co_filename == __file__ \ + and any(arg is torch.Tensor or arg == "torch.Tensor" + for arg in v.__annotations__.values()): names_and_values_to_update[k] = hint_on_error(v) names_and_values.update(names_and_values_to_update) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e4c8c974623f7..fa9b02ff13ea6 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -22,9 +22,8 @@ def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: - result_rms = torch.empty_like(input) - torch.ops._C.rms_norm(result_rms, input, weight, epsilon) - torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale) + # Last two are scale_ub, residual + torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, None) @torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") @@ -41,8 +40,8 @@ def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: - torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) - torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale) + # Last two are scale_ub, residual + torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, residual) @torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") From 7945e6217fbbcbbb23c65c54e60ffb5a07a15441 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 21 Nov 2024 23:34:07 +0000 Subject: [PATCH 09/21] In progress dynamic fusion debugging --- csrc/torch_bindings.cpp | 2 +- tests/compile/test_fusion.py | 20 +++--- tests/kernels/test_fused_quant_layernorm.py | 4 +- vllm/_custom_ops.py | 21 +++---- vllm/compilation/fix_functionalization.py | 8 +-- vllm/compilation/fusion.py | 70 +++++++-------------- 6 files changed, 47 insertions(+), 78 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d5de65a5c9e30..1ffab14862fed 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -131,7 +131,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Fused Layernorm + Quant kernels ops.def( "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " - "Tensor weight, Tensor! scales, float epsilon, " + "Tensor weight, Tensor! scale, float epsilon, " "Tensor? scale_ub, Tensor!? 
residual) -> ()"); ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, &rms_norm_dynamic_per_token_quant); diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index d30ee82117866..fa1765d6ad84a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -52,11 +52,14 @@ def forward(self, x): reason="Only test on CUDA") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): torch.set_default_device("cuda") - torch.set_default_dtype(torch.float16) + torch.set_default_dtype(dtype) + torch.manual_seed(1) # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True) + config = CompilationConfig.PassConfig( + enable_fusion=True, + enable_reshape=True, + dump_graph_stages=["before_fusion", "after_fusion"]) reshape_pass = RedundantReshapesPass(config) fusion_pass = FusionPass.instance(config) @@ -73,8 +76,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): result2 = model2(x) # Check that it gives the same answer, higher tol for dynamic - ATOL, RTOL = (1e-3, 1e-3) if static else (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + ATOL, RTOL = (1e-3, 1e-3) if static else (2e-2, 2e-2) + torch.testing.assert_close(result.to(dtype=torch.float32), + result2.to(dtype=torch.float32), + atol=ATOL, + rtol=RTOL) # Check substitution worked pre_nodes = backend.graph_pre_pass.nodes @@ -85,8 +91,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501 fp8_quant = torch.ops._C.static_scaled_fp8_quant.default else: - rms_quant = torch.ops._C.rms_norm_dynamic_fp8_quant.default - add_rms_quant = torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default # noqa: E501 + rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default + add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default # noqa: E501 fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default # In pre-nodes, fp8 quant should be present and fused kernels should not diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index ec3ad5ab5dd8b..15015063658ab 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -8,8 +8,8 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, +NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing +HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases ADD_RESIDUAL = [False, True] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8503452df15df..bed3dad57c580 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -272,19 +272,14 @@ def rms_norm_dynamic_per_token_quant( # TODO is this necessary? 
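
The @register_fake registration below gives torch.compile a shape/dtype-only ("fake") implementation so graphs containing the op can be traced without launching the CUDA kernel; because the op fills preallocated outputs and returns nothing, the fake simply returns None. A self-contained toy showing the same idiom, with a made-up op name ("demo::scale_inplace") and assuming a PyTorch version that provides torch.library.custom_op / register_fake:

import torch

@torch.library.custom_op("demo::scale_inplace", mutates_args=("out",))
def scale_inplace(out: torch.Tensor, x: torch.Tensor, s: float) -> None:
    out.copy_(x * s)

@torch.library.register_fake("demo::scale_inplace")
def _scale_inplace_fake(out: torch.Tensor, x: torch.Tensor, s: float) -> None:
    # Outputs are preallocated by the caller, so there is nothing to infer.
    return None
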
@register_fake("_C::rms_norm_dynamic_per_token_quant") def _rms_norm_dynamic_per_token_quant_fake( - input: torch.Tensor, - weight: torch.Tensor, - epsilon: float, - quant_dtype: torch.dtype, - scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=quant_dtype) - scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - - return output, scales + output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + epsilon: float, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None) -> None: + return None # quantization ops diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 1c87c7771b21d..e4661d552931d 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -60,19 +60,15 @@ def __call__(self, graph: torch.fx.Graph): elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 mutated_args = {1: 'result', 2: 'residual'} self.defunctionalize(graph, node, mutated_args) - elif at_target == torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default: # noqa: E501 - mutated_args = {1: 'result', 2: 'residual', 3: 'scale'} + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} self.defunctionalize(graph, node, mutated_args) - elif at_target in [ torch.ops._C.rms_norm.default, torch.ops._C.rms_norm_static_fp8_quant.default ]: mutated_args = {1: 'result'} self.defunctionalize(graph, node, mutated_args) - elif at_target == torch.ops._C.rms_norm_dynamic_fp8_quant.default: - mutated_args = {1: 'result', 2: 'scale'} - self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.silu_and_mul.default: mutated_args = {1: 'out'} diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index fa9b02ff13ea6..907f9ad2c8a7b 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -16,40 +16,6 @@ logger = init_logger(__name__) -# TODO temp -@torch.library.custom_op("_C::rms_norm_dynamic_fp8_quant", - mutates_args=("result", "scale")) -def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor, - epsilon: float) -> None: - # Last two are scale_ub, residual - torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, None) - - -@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor, epsilon: float): - return None - - -@torch.library.custom_op("_C::fused_add_rms_norm_dynamic_fp8_quant", - mutates_args=("result", "residual", "scale")) -def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - epsilon: float) -> None: - # Last two are scale_ub, residual - torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, residual) - - -@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor, epsilon: float): - return None - - def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, 
device="cuda") @@ -372,12 +338,14 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized( - torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, result=result, input=input, weight=weight, scale=scale, - epsilon=self.epsilon) + epsilon=self.epsilon, + scale_ub=None, + residual=None) # result, scale return at[1], at[2] @@ -413,7 +381,7 @@ def process(self): # The auto_fn node returns a tuple of (None, result, scale). # # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, ...) # noqa + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa # result_node_new = at[1] # scale_node_new = at[2] with self.inserting_after_match(): @@ -421,10 +389,12 @@ def process(self): # Scalars cannot be inputs to the pattern kwargs["epsilon"] = rms_node.kwargs["epsilon"] + kwargs["scale_ub"] = None # not used but required + kwargs["residual"] = None # not used but required del kwargs["result_rms"] # not used in the fused op fused_node = self.insert_auto_fn( - torch.ops._C.rms_norm_dynamic_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, kwargs=kwargs) getitem_nodes = self.insert_getitems(fused_node, (1, 2)) @@ -466,16 +436,17 @@ def replacement(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized( - torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, result=result, input=input, - residual=residual, weight=weight, scale=scale, - epsilon=self.epsilon) + epsilon=self.epsilon, + scale_ub=None, + residual=residual) # result, residual, scale - return at[1], at[2], at[3] # TODO confirm signature + return at[1], at[3], at[2] inputs = [ empty_fp8(5, 4), # result @@ -508,22 +479,23 @@ def process(self): # The auto_fn node returns a tuple (None, result, scale, residual). # # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, ...) # noqa + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa # result_node_new = at[1] - # residual_node_new = at[2] - # scale_node_new = at[3] + # scale_node_new = at[2] + # residual_node_new = at[3] with self.inserting_after_match(): kwargs = self.match.kwargs.copy() # Scalars cannot be inputs to the pattern kwargs["epsilon"] = rms_node.kwargs["epsilon"] + kwargs["scale_ub"] = None # not used but required fused_node = self.insert_auto_fn( - torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, kwargs=kwargs) getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) - result_node_new, residual_node_new, scale_node_new = getitem_ns + result_node_new, scale_node_new, residual_node_new = getitem_ns # Rebind the users of match getitem nodes to use the new nodes. # The old nodes will be removed by DCE at the end of the pass. 
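
In plain torch.fx terms, the manual processing of these multi-output matches boils down to the following sketch. It is simplified: the real pass also patches the fake-tensor meta values on the new nodes, and the helper name and argument list here are made up:

import operator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def insert_fused_and_rebind(graph: fx.Graph, insert_point: fx.Node,
                            fused_op, kwargs, old_getitems):
    # Insert auto_functionalized(fused_op, **kwargs) after the matched region,
    # extract its tuple outputs, and point the old users at the new nodes.
    with graph.inserting_after(insert_point):
        fused = graph.call_function(auto_functionalized, (fused_op,), kwargs)
    with graph.inserting_after(fused):
        new_getitems = [
            graph.call_function(operator.getitem, (fused, i))
            for i in range(1, len(old_getitems) + 1)
        ]
    for old, new in zip(old_getitems, new_getitems):
        old.replace_all_uses_with(new)
    # The replaced nodes become dead and are removed by DCE afterwards.
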
@@ -588,12 +560,12 @@ def __init__(self, config: CompilationConfig.PassConfig): self.patterns, self.record_match) # Fuse rms_norm + dynamic_scaled_fp8_quant into - # rms_norm_dynamic_fp8_quant + # rms_norm_dynamic_per_token_quant RMSNormDynamicFP8QuantPattern(epsilon).register( self.patterns, self.record_match) # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into - # fused_add_rms_norm_dynamic_fp8_quant + # rms_norm_dynamic_per_token_quant FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register( self.patterns, self.record_match) From 2a17c5d5c24ad448b857e2c90f1c03f59e1b0148 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 3 Dec 2024 18:50:14 +0000 Subject: [PATCH 10/21] Nit comment Signed-off-by: luka --- .../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index e4c9d66c33b40..7bb406458983b 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -36,7 +36,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( has_residual>( out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); } else { - // FP8 - Do not invert s_token_scale for exact match with FBGemm + // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( out, input, weight, rms, token_scale, hidden_size, residual); From 260443e2317bc06b870a6b014c8094b1802db995 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 4 Dec 2024 22:37:46 +0000 Subject: [PATCH 11/21] Fix fusion and tests to use dynamic per-token Signed-off-by: luka --- tests/compile/test_fusion.py | 31 +++++++++++++++++++++---------- vllm/_custom_ops.py | 1 - vllm/compilation/fusion.py | 12 +++++++----- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index fa1765d6ad84a..6dc989f0d634c 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -31,14 +31,22 @@ def __init__(self, hidden_size: int, eps: float, static: bool, *args, ] def forward(self, x): - resid = torch.relu(x) + resid = torch.sqrt(x) y = self.norm[0](x) - x2 = apply_fp8_linear(y, self.w[0], self.wscale[0], self.scale[0]) + x2 = apply_fp8_linear(y, + self.w[0], + self.wscale[0], + self.scale[0], + use_per_token_if_dynamic=True) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = apply_fp8_linear(y2, self.w[1], self.wscale[1], self.scale[1]) + x3 = apply_fp8_linear(y2, + self.w[1], + self.wscale[1], + self.scale[1], + use_per_token_if_dynamic=True) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -75,12 +83,15 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): model2 = torch.compile(model, backend=backend) result2 = model2(x) - # Check that it gives the same answer, higher tol for dynamic - ATOL, RTOL = (1e-3, 1e-3) if static else (2e-2, 2e-2) - torch.testing.assert_close(result.to(dtype=torch.float32), - result2.to(dtype=torch.float32), - atol=ATOL, - rtol=RTOL) + # Higher tol for dynamic, even higher for bfloat16 + if static: + ATOL, RTOL = (1e-3, 1e-3) + elif dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) # Check substitution 
worked pre_nodes = backend.graph_pre_pass.nodes @@ -93,7 +104,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): else: rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default # noqa: E501 - fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default + fp8_quant = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default # In pre-nodes, fp8 quant should be present and fused kernels should not assert find_auto_fn_maybe(pre_nodes, rms_quant) is None diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bed3dad57c580..3808fb9a87e56 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -269,7 +269,6 @@ def rms_norm_dynamic_per_token_quant( return output, scales -# TODO is this necessary? @register_fake("_C::rms_norm_dynamic_per_token_quant") def _rms_norm_dynamic_per_token_quant_fake( output: torch.Tensor, diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 907f9ad2c8a7b..823e66867f28a 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -163,7 +163,7 @@ def insert_auto_fn(self, op, kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default -QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_scaled_fp8_quant.default +QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default class RMSNormQuantPattern: @@ -329,7 +329,8 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor, at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, result=result, input=at1[1], - scale=scale) + scale=scale, + scale_ub=None) # result, scale return at2[1], at2[2] @@ -427,7 +428,8 @@ def pattern(result: torch.Tensor, input: torch.Tensor, at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, result=result, input=at[1], - scale=scale) + scale=scale, + scale_ub=None) # result, residual, scale return at1[1], at[2], at1[2] @@ -559,12 +561,12 @@ def __init__(self, config: CompilationConfig.PassConfig): FusedAddRMSNormStaticFP8QuantPattern(epsilon).register( self.patterns, self.record_match) - # Fuse rms_norm + dynamic_scaled_fp8_quant into + # Fuse rms_norm + dynamic_per_token_scaled_fp8_quant into # rms_norm_dynamic_per_token_quant RMSNormDynamicFP8QuantPattern(epsilon).register( self.patterns, self.record_match) - # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into + # Fuse fused_add_rms_norm + dynamic_per_token_scaled_fp8_quant into # rms_norm_dynamic_per_token_quant FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register( self.patterns, self.record_match) From 69d8cfc16e91835bbaac345444bc7095efb6ba11 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 4 Dec 2024 22:39:18 +0000 Subject: [PATCH 12/21] Remove debug graph output Signed-off-by: luka --- tests/compile/test_fusion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6dc989f0d634c..6c13be791f2ea 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -64,10 +64,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): torch.manual_seed(1) # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig( - enable_fusion=True, - enable_reshape=True, - dump_graph_stages=["before_fusion", "after_fusion"]) + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_reshape=True) reshape_pass = RedundantReshapesPass(config) fusion_pass = 
FusionPass.instance(config) From e07e032ce9f5344b543719246e2bd6315d92b641 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 4 Dec 2024 22:43:24 +0000 Subject: [PATCH 13/21] PR comments Signed-off-by: luka --- tests/kernels/test_fused_quant_layernorm.py | 4 ++-- vllm/_custom_ops.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index 15015063658ab..3997f4e9b8fe9 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -9,8 +9,8 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192, - 8193] # Arbitrary values for testing +HIDDEN_SIZES = [1, 3, 4, 16, 64, 2048, 5120, + 5137] # Arbitrary values for testing HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3808fb9a87e56..8d5dfebc4c03b 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -22,7 +22,6 @@ supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 - supports_moe_ops = True # neuron has torch version that doesn't even have impl_abstract @@ -242,6 +241,7 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, paged_kv_indptr: torch.Tensor, paged_kv_last_page_len: torch.Tensor, block_table_bound: torch.Tensor) -> None: + return torch.ops._C.advance_step_flashinfer( num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, @@ -737,7 +737,7 @@ def scaled_fp8_quant( shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + if current_platform.is_rocm() else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=out_dtype) @@ -1020,9 +1020,9 @@ def register_graph_buffers(fa: int, handles: List[List[int]], # the case when users use `import __annotations__` to turn type # hints into strings. 
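
On the hidden sizes kept in the test above: range(1024, 1033) straddles the vectorization boundary, since the CUDA dispatch only takes the vec4 path when hidden_size % 4 == 0, so consecutive sizes exercise both the vectorized and scalar fallback paths:

sizes = list(range(1024, 1033))
print([s % 4 == 0 for s in sizes])
# [True, False, False, False, True, False, False, False, True]
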
if isinstance(v, fn_type) \ - and v.__code__.co_filename == __file__ \ - and any(arg is torch.Tensor or arg == "torch.Tensor" - for arg in v.__annotations__.values()): + and v.__code__.co_filename == __file__ \ + and any(arg is torch.Tensor or arg == "torch.Tensor" + for arg in v.__annotations__.values()): names_and_values_to_update[k] = hint_on_error(v) names_and_values.update(names_and_values_to_update) From 99d0e21757f20e52f3d731b5539d2fdc794399c0 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 5 Dec 2024 00:19:59 +0000 Subject: [PATCH 14/21] Fix TPU test Signed-off-by: luka --- vllm/_custom_ops.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 8d5dfebc4c03b..1964b934e1986 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -269,16 +269,18 @@ def rms_norm_dynamic_per_token_quant( return output, scales -@register_fake("_C::rms_norm_dynamic_per_token_quant") -def _rms_norm_dynamic_per_token_quant_fake( - output: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scales: torch.Tensor, - epsilon: float, - scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None) -> None: - return None +if hasattr(torch.ops._C, "rms_norm_dynamic_per_token_quant"): + + @register_fake("_C::rms_norm_dynamic_per_token_quant") + def _rms_norm_dynamic_per_token_quant_fake( + output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + epsilon: float, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None) -> None: + return None # quantization ops From aa4d86c7f0423273939ec02f25a69754ab798f5f Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 5 Dec 2024 00:31:39 +0000 Subject: [PATCH 15/21] Abstract out quantization type to allow more quant types more easily Signed-off-by: luka --- vllm/compilation/fusion.py | 270 ++++++++++++++++++++++++------------- 1 file changed, 179 insertions(+), 91 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 823e66867f28a..7f9c687fa4cd6 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -5,6 +5,8 @@ import torch import torch._inductor.pattern_matcher as pm +# TODO(luka) use vllm.utils once #10836 landed +from compressed_tensors.quantization import FP8_DTYPE from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -20,11 +22,6 @@ def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") -def empty_fp8(*args, **kwargs): - fp8 = torch.float8_e4m3fn - return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") - - def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") @@ -162,17 +159,79 @@ def insert_auto_fn(self, op, kwargs): RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default -QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default +# Key: (fp8/int8, static/dynamic, per-tensor/per-token, symmetric/asymmetric) +# Value: the torch op +QUANT_OPS = { + (FP8_DTYPE, True, True, True): + torch.ops._C.static_scaled_fp8_quant.default, + (FP8_DTYPE, False, True, True): + torch.ops._C.dynamic_scaled_fp8_quant.default, + (FP8_DTYPE, False, False, True): + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, +} + +# Key: 
(quant_key, fused_add) +FUSED_OPS = { + ((FP8_DTYPE, True, True, True), False): + torch.ops._C.rms_norm_static_fp8_quant.default, + ((FP8_DTYPE, True, True, True), True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + ((FP8_DTYPE, False, False, True), False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, + ((FP8_DTYPE, False, False, True), True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, +} class RMSNormQuantPattern: - def __init__(self, epsilon: float): + def __init__(self, + epsilon: float, + fused_add: bool, + quant_dtype: torch.dtype, + static: bool, + per_tensor: bool = True, + symmetric=True): self.epsilon = epsilon + self.quant_dtype = quant_dtype + + # nicer assert + keystr = lambda: ( + f"({'static' if static else 'dynamic'}, {quant_dtype}, " + f"{'per_tensor' if per_tensor else 'per_token'}, " + f"{'a' if not symmetric else ''}symmetric)") + + key = (quant_dtype, static, per_tensor, symmetric) + assert key in QUANT_OPS, f"unsupported quantization scheme {keystr()}" + self.QUANT_OP = QUANT_OPS[key] + + key2 = (key, fused_add) + assert key2 in FUSED_OPS, ( + f"unsupported fused rmsnorm+quant op with" + f"{'out' if not fused_add else ''} residual)" + f" for quant scheme {keystr()})") + self.FUSED_OP = FUSED_OPS[key2] + + class Match(MultiOutputMatch): + + def __init__(self, match: pm.Match, quant_op, fused_op): + super().__init__(match) + self.QUANT_OP = quant_op + self.FUSED_OP = fused_op -class RMSNormStaticFP8QuantPattern(RMSNormQuantPattern): +class RMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + super().__init__(epsilon, + fused_add=False, + quant_dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing @@ -184,7 +243,7 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor, input=input, weight=weight, epsilon=self.epsilon) - at2 = auto_functionalized(QUANT_STATIC_FP8_OP, + at2 = auto_functionalized(self.QUANT_OP, result=result, input=at1[1], scale=scale) @@ -195,19 +254,18 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor, def replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized( - torch.ops._C.rms_norm_static_fp8_quant.default, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) # result return at[1] inputs = [ - empty_fp8(5, 4), # result + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight @@ -218,7 +276,18 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, pm_pass) -class FusedAddRMSNormStaticFP8QuantPattern(RMSNormQuantPattern): +class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + super().__init__(epsilon, + fused_add=True, + quant_dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric) def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): @@ -231,7 +300,7 @@ def pattern(result: torch.Tensor, input: torch.Tensor, residual=residual, weight=weight, 
epsilon=self.epsilon) - at1 = auto_functionalized(QUANT_STATIC_FP8_OP, + at1 = auto_functionalized(self.QUANT_OP, result=result, input=at[1], scale=scale) @@ -242,20 +311,19 @@ def pattern(result: torch.Tensor, input: torch.Tensor, def replacement(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) # result, residual return at[1], at[2] inputs = [ - empty_fp8(5, 4), # result + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight @@ -268,14 +336,15 @@ def replacement(result: torch.Tensor, input: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match(self.Match(m))) + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(MultiOutputMatch): + class Match(RMSNormQuantPattern.Match): def process(self): # Find the nodes in the match that we need to rebind rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(QUANT_STATIC_FP8_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) assert len(rms_node.users) == 2 assert len(quant_node.users) == 1 @@ -294,10 +363,8 @@ def process(self): # Scalars cannot be inputs to the pattern kwargs["epsilon"] = rms_node.kwargs["epsilon"] - fused_node = self.insert_auto_fn( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - kwargs) - + # TODO simplify + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) getitem_nodes = self.insert_getitems(fused_node, (1, 2)) result_node_new, residual_node_new = getitem_nodes @@ -313,7 +380,19 @@ def process(self): fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2]) -class RMSNormDynamicFP8QuantPattern(RMSNormQuantPattern): +class RMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool, + symmetric=True): + super().__init__(epsilon, + fused_add=False, + quant_dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric) def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): @@ -326,7 +405,7 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor, input=input, weight=weight, epsilon=self.epsilon) - at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, + at2 = auto_functionalized(self.QUANT_OP, result=result, input=at1[1], scale=scale, @@ -338,21 +417,20 @@ def pattern(result: torch.Tensor, result_rms: torch.Tensor, def replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized( - torch.ops._C.rms_norm_dynamic_per_token_quant.default, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) # result, scale return at[1], at[2] inputs = [ - empty_fp8(5, 4), # result + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # result_rms 
empty_bf16(5, 4), # input empty_bf16(1, 5), # weight @@ -365,14 +443,15 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match(self.Match(m))) + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(MultiOutputMatch): + class Match(RMSNormQuantPattern.Match): def process(self): # Find the nodes in the match that we need to rebind rms_node = self.find_auto_fn(RMS_OP) - quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) assert len(rms_node.users) == 1 assert len(quant_node.users) == 2 @@ -394,10 +473,8 @@ def process(self): kwargs["residual"] = None # not used but required del kwargs["result_rms"] # not used in the fused op - fused_node = self.insert_auto_fn( - torch.ops._C.rms_norm_dynamic_per_token_quant.default, - kwargs=kwargs) - + # TODO simplify + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs) getitem_nodes = self.insert_getitems(fused_node, (1, 2)) result_node_new, scale_node_new = getitem_nodes @@ -412,7 +489,19 @@ def process(self): fused_node.meta["val"] = quant_node.meta["val"] -class FusedAddRMSNormDynamicFP8QuantPattern(RMSNormQuantPattern): +class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool = True, + symmetric=True): + super().__init__(epsilon, + fused_add=True, + quant_dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric) def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): @@ -425,7 +514,7 @@ def pattern(result: torch.Tensor, input: torch.Tensor, residual=residual, weight=weight, epsilon=self.epsilon) - at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP, + at1 = auto_functionalized(self.QUANT_OP, result=result, input=at[1], scale=scale, @@ -437,21 +526,20 @@ def pattern(result: torch.Tensor, input: torch.Tensor, def replacement(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized( - torch.ops._C.rms_norm_dynamic_per_token_quant.default, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual) + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) # result, residual, scale return at[1], at[3], at[2] inputs = [ - empty_fp8(5, 4), # result + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight @@ -464,14 +552,15 @@ def replacement(result: torch.Tensor, input: torch.Tensor, inputs, pm.fwd_only, pm_pass, - extra_check=lambda m: record_match(self.Match(m))) + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(MultiOutputMatch): + class Match(RMSNormQuantPattern.Match): def process(self): # Find the nodes in the match that we need to rebind rms_node = self.find_auto_fn(RMS_ADD_OP) - quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) assert len(rms_node.users) == 2 assert len(quant_node.users) == 2 @@ -492,10 +581,7 @@ def process(self): kwargs["epsilon"] = rms_node.kwargs["epsilon"] kwargs["scale_ub"] = None # not used but required - fused_node = self.insert_auto_fn( - 
torch.ops._C.rms_norm_dynamic_per_token_quant.default, - kwargs=kwargs) - + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs) getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) result_node_new, scale_node_new, residual_node_new = getitem_ns @@ -550,27 +636,29 @@ def __init__(self, config: CompilationConfig.PassConfig): pass_name="fusion_pass") for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static_scaled_fp8_quant into - # rms_norm_static_fp8_quant - RMSNormStaticFP8QuantPattern(epsilon).register(self.patterns) - - # Fuse fused_add_rms_norm + static_scaled_fp8_quant into - # fused_add_rms_norm_static_fp8_quant - # Because pattern has 2 outputs, we need to manually process - # the match (see process_matches) - FusedAddRMSNormStaticFP8QuantPattern(epsilon).register( - self.patterns, self.record_match) + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, + FP8_DTYPE).register(self.patterns) - # Fuse rms_norm + dynamic_per_token_scaled_fp8_quant into - # rms_norm_dynamic_per_token_quant - RMSNormDynamicFP8QuantPattern(epsilon).register( - self.patterns, self.record_match) + # Matches for patterns below have 2 or more outputs, + # so we need to process them manually (see process_matches) - # Fuse fused_add_rms_norm + dynamic_per_token_scaled_fp8_quant into - # rms_norm_dynamic_per_token_quant - FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register( + # Fuse rms_norm + static fp8 quant + FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns, self.record_match) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE, + per_tensor=False).register( + self.patterns, self.record_match) + + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + FusedAddRMSNormDynamicQuantPattern(epsilon, + FP8_DTYPE, + per_tensor=False).register( + self.patterns, + self.record_match) + # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. 
torch._inductor.pattern_matcher._seen_patterns.clear() From 2e6a0cb462d127b220696c781f546534b715daad Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 5 Dec 2024 01:26:46 +0000 Subject: [PATCH 16/21] Extract getitem replacement and meta value fixing into insert_fused_node Signed-off-by: luka --- vllm/compilation/fusion.py | 183 ++++++++++++++++++++----------------- 1 file changed, 101 insertions(+), 82 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 7f9c687fa4cd6..0a2b1c2297c1e 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,12 +1,13 @@ import abc import operator from abc import abstractmethod -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Tuple import torch import torch._inductor.pattern_matcher as pm # TODO(luka) use vllm.utils once #10836 landed from compressed_tensors.quantization import FP8_DTYPE +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -30,8 +31,7 @@ def empty_fp32(*args, **kwargs): # Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], - op) -> Optional[torch.fx.Node]: +def find_auto_fn_maybe(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: for node in nodes: if is_func(node, auto_functionalized) and node.args[0] == op: # noqa return node @@ -39,7 +39,7 @@ def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], # Returns the first auto_functionalized node with the given op -def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: +def find_auto_fn(nodes: Iterable[fx.Node], op) -> fx.Node: node = find_auto_fn_maybe(nodes, op) assert node is not None, f"Could not find {op} in nodes {nodes}" return node @@ -47,8 +47,7 @@ def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: # Returns the getitem node that extracts the idx-th element from node # (if it exists) -def find_getitem_maybe(node: torch.fx.Node, - idx: int) -> Optional[torch.fx.Node]: +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: for user in node.users: if is_func(user, operator.getitem) and user.args[1] == idx: return user @@ -56,7 +55,7 @@ def find_getitem_maybe(node: torch.fx.Node, # Returns the getitem node that extracts the idx-th element from node -def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: +def find_getitem(node: fx.Node, idx: int) -> fx.Node: ret = find_getitem_maybe(node, idx) assert ret is not None, f"Could not find getitem {idx} in node {node}" return ret @@ -104,14 +103,14 @@ def process(self): raise NotImplementedError @property - def nodes(self) -> List[torch.fx.Node]: + def nodes(self) -> List[fx.Node]: return self.match.nodes @property - def graph(self) -> torch.fx.Graph: + def graph(self) -> fx.Graph: return self.match.graph - def find_auto_fn(self, op) -> torch.fx.Node: + def find_auto_fn(self, op) -> fx.Node: """ Find the first auto_functionalized node with the given op in the match. """ @@ -134,8 +133,8 @@ def inserting_after_match(self): return self.graph.inserting_after(last_node_in_match) - def insert_getitems(self, tuple_node: torch.fx.Node, - indices: Tuple[int, ...]) -> Tuple[torch.fx.Node, ...]: + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> Tuple[fx.Node, ...]: """ Insert operator.getitem nodes to extract elements from a tuple node. 
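# Aside (illustrative sketch, not part of the patch): relaxing the annotation
# from Tuple[int, ...] to Iterable[int] lets callers pass a dict key view
# directly instead of first materializing a tuple, e.g. in insert_fused_node
# below:
#
#     fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
#     getitem_nodes = self.insert_getitems(fused_node,
#                                          fused_return_mapping.keys())
#     result_node_new, residual_node_new = getitem_nodes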
@@ -160,7 +159,6 @@ def insert_auto_fn(self, op, kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default # Key: (fp8/int8, static/dynamic, per-tensor/per-token, symmetric/asymmetric) -# Value: the torch op QUANT_OPS = { (FP8_DTYPE, True, True, True): torch.ops._C.static_scaled_fp8_quant.default, @@ -183,6 +181,66 @@ def insert_auto_fn(self, op, kwargs): } +class QuantMultiOutputMatch(MultiOutputMatch): + + def __init__(self, match: pm.Match, quant_op, fused_op): + super().__init__(match) + self.QUANT_OP = quant_op + self.FUSED_OP = fused_op + + def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, + int]], + **kwargs): + """ + This utility function inserts an auto-functionalized node for FUSED_OP. + It also correctly sets its meta value and rebinds the users of the + unfused nodes to use the fused node instead. + + :param fused_return_mapping: A dictionary, mapping from getitem indices + of the fused node result to a tuple of the old node and a getitem index. + :param kwargs: kwargs that get directly forwarded to the auto_fn node + + Example: + If we want to replace this graph: + _, x1, x2 = auto_fn(op1) + _, y1, y2 = auto_fn(op2) + + with + _, x1, y2, x2 = auto_fn(FUSED_OP) + + we would call: + insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} + + Note that the 0th element is None for auto-functionalized in-place ops. + Hence others appear 1-indexed. + """ + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) + indices = fused_return_mapping.keys() + getitem_nodes = self.insert_getitems(fused_node, indices) + + # Prepare the meta value, use a list so it's mutable + meta_val = [None] * (max(indices) + 1) + + # Iterate through elements of the tuple produced by fused_node + for idx, getitem_node in zip(indices, getitem_nodes): + old_node, old_idx = fused_return_mapping[idx] + + # If the old value was never used, the old_getitem might not exist + old_getitem = find_getitem_maybe(old_node, old_idx) + if old_getitem is not None: + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. 
+ old_getitem.replace_all_uses_with(getitem_node) + getitem_node.meta["val"] = old_getitem.meta["val"] + + # Extract the appropriate meta value + # It is present even if the getitem node does not exist + meta_val[idx] = old_node.meta["val"][old_idx] + + # Fix the meta value on the new fused node + fused_node.meta["val"] = tuple(meta_val) + + class RMSNormQuantPattern: def __init__(self, @@ -212,13 +270,6 @@ def __init__(self, f" for quant scheme {keystr()})") self.FUSED_OP = FUSED_OPS[key2] - class Match(MultiOutputMatch): - - def __init__(self, match: pm.Match, quant_op, fused_op): - super().__init__(match) - self.QUANT_OP = quant_op - self.FUSED_OP = fused_op - class RMSNormStaticQuantPattern(RMSNormQuantPattern): @@ -339,7 +390,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor, extra_check=lambda m: record_match( self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(RMSNormQuantPattern.Match): + class Match(QuantMultiOutputMatch): def process(self): # Find the nodes in the match that we need to rebind @@ -358,26 +409,14 @@ def process(self): # result_node_new = at[1] # residual_node_new = at[2] with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern kwargs = self.match.kwargs.copy() - # Scalars cannot be inputs to the pattern - kwargs["epsilon"] = rms_node.kwargs["epsilon"] - - # TODO simplify - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) - getitem_nodes = self.insert_getitems(fused_node, (1, 2)) - result_node_new, residual_node_new = getitem_nodes - - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - - # Finally, fix meta["val"] for de-functionalization. - # See MultiOutputMatch.process for more details. - rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] - # Result of fused node is (None, result, residual) - fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2]) + # 0 is always None + fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} + self.insert_fused_node(fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + **kwargs) class RMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -446,7 +485,7 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, extra_check=lambda m: record_match( self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(RMSNormQuantPattern.Match): + class Match(QuantMultiOutputMatch): def process(self): # Find the nodes in the match that we need to rebind @@ -465,28 +504,17 @@ def process(self): # result_node_new = at[1] # scale_node_new = at[2] with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern kwargs = self.match.kwargs.copy() - - # Scalars cannot be inputs to the pattern - kwargs["epsilon"] = rms_node.kwargs["epsilon"] - kwargs["scale_ub"] = None # not used but required - kwargs["residual"] = None # not used but required del kwargs["result_rms"] # not used in the fused op - # TODO simplify - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs) - getitem_nodes = self.insert_getitems(fused_node, (1, 2)) - result_node_new, scale_node_new = getitem_nodes - - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. 
- find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) - - # Finally, fix meta["val"] for de-functionalization. - # See MultiOutputMatch.process for more details. - # Result of fused node is (None, result, scale) - fused_node.meta["val"] = quant_node.meta["val"] + fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + residual=None, # not used but required + **kwargs) class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): @@ -555,7 +583,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor, extra_check=lambda m: record_match( self.Match(m, self.QUANT_OP, self.FUSED_OP))) - class Match(RMSNormQuantPattern.Match): + class Match(QuantMultiOutputMatch): def process(self): # Find the nodes in the match that we need to rebind @@ -575,28 +603,19 @@ def process(self): # scale_node_new = at[2] # residual_node_new = at[3] with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern kwargs = self.match.kwargs.copy() - # Scalars cannot be inputs to the pattern - kwargs["epsilon"] = rms_node.kwargs["epsilon"] - kwargs["scale_ub"] = None # not used but required - - fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs=kwargs) - getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) - result_node_new, scale_node_new, residual_node_new = getitem_ns - - # Rebind the users of match getitem nodes to use the new nodes. - # The old nodes will be removed by DCE at the end of the pass. - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new) - - # Finally, fix meta["val"] for de-functionalization. - # See MultiOutputMatch.process for more details. - rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"] - # Result of fused node is (None, result, scale, residual) - fused_node.meta["val"] = (None, quant_tup[1], quant_tup[2], - rms_tup[2]) + fused_return_mapping = { + 1: (quant_node, 1), # result + 2: (quant_node, 2), # scale + 3: (rms_node, 2), # residual + } + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + **kwargs) class FusionPass(VllmInductorPass): @@ -671,7 +690,7 @@ def record_match(self, match: MultiOutputMatch) -> bool: # Return False to prevent automatic replacement. return False - def process_matches(self, graph: torch.fx.Graph): + def process_matches(self, graph: fx.Graph): """ Manually process multi-output matches and replace them with fused nodes. See MultiOutputMatch for more details. 
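# Aside (minimal sketch of the flow described above; exact pass internals may
# differ slightly): multi-output patterns are never auto-replaced. Instead,
# record_match stores the match and returns False from extra_check, and the
# stored matches are replayed after the pattern-matcher pass has run:
#
#     def record_match(self, match: MultiOutputMatch) -> bool:
#         self.matches.append(match)
#         return False          # block the (broken) automatic replacement
#
#     def process_matches(self, graph: fx.Graph):
#         for match in self.matches:
#             match.process()   # insert fused op, rebind getitems, fix meta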
@@ -684,7 +703,7 @@ def process_matches(self, graph: torch.fx.Graph): assert all(node not in graph.nodes for match in self.matches for node in match.match.nodes) - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: fx.Graph): self.begin() self.dump_graph(graph, "before_fusion") From 9c9ea8c0fc13271c956e331dcae430c17119ae83 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Dec 2024 21:02:43 +0000 Subject: [PATCH 17/21] - OpOverload typing - extracted MultiOutputMatch to own file - extracted utils to fx_utils - added named tuples for op keys Signed-off-by: luka --- tests/compile/test_functionalization.py | 21 +- tests/compile/test_fusion.py | 21 +- vllm/compilation/fix_functionalization.py | 3 +- vllm/compilation/fusion.py | 296 +++++++--------------- vllm/compilation/fx_utils.py | 42 +++ vllm/compilation/multi_output_match.py | 105 ++++++++ vllm/compilation/reshapes.py | 3 +- vllm/compilation/vllm_inductor_pass.py | 4 - 8 files changed, 271 insertions(+), 224 deletions(-) create mode 100644 vllm/compilation/fx_utils.py create mode 100644 vllm/compilation/multi_output_match.py diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 5036189077be2..ea3aaee9565ec 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -4,10 +4,10 @@ import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FusionPass, find_auto_fn, - find_auto_fn_maybe) +from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, + kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.reshapes import RedundantReshapesPass -from vllm.compilation.vllm_inductor_pass import is_func from vllm.config import CompilationConfig from .backend import TestBackend @@ -35,12 +35,16 @@ ] -@pytest.mark.parametrize("model", - ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"]) +@pytest.mark.parametrize( + "model, quant_key", + [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e", + kFp8DynamicTokenSym)]) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model: str, do_fusion: bool): +def test_fix_functionalization(model: str, quant_key: QuantKey, + do_fusion: bool): torch.set_default_device("cuda") config = CompilationConfig.PassConfig(enable_fusion=do_fusion, @@ -78,8 +82,9 @@ def test_fix_functionalization(model: str, do_fusion: bool): # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, # and replaced by fused quantized ops in RMS_QUANT_OPS. 
- ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"] - if do_fusion else [RMS_OP]) + rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] + ] if do_fusion else [RMS_OP] + ops = OPS_IN_MODEL + rms_ops for op in ops: find_auto_fn(backend_no_func.graph_post_pass.nodes, op) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6c13be791f2ea..b4266a4a7db94 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,8 +3,9 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.fusion import (FusionPass, find_auto_fn, - find_auto_fn_maybe) +from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, + FusionPass, QuantKey) +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.reshapes import RedundantReshapesPass from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm @@ -95,14 +96,14 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): pre_nodes = backend.graph_pre_pass.nodes post_nodes = backend.graph_post_pass.nodes - if static: - rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default - add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501 - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default - else: - rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default - add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default # noqa: E501 - fp8_quant = torch.ops._C.dynamic_per_token_scaled_fp8_quant.default + # static is per-tensor, dynamic is per-token + key = QuantKey(dtype=FP8_DTYPE, + static=static, + per_tensor=static, + symmetric=True) + rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] + add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] + fp8_quant = QUANT_OPS[key] # In pre-nodes, fp8 quant should be present and fused kernels should not assert find_auto_fn_maybe(pre_nodes, rms_quant) is None diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index e4661d552931d..e15d7b315c50f 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -6,7 +6,8 @@ from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 0a2b1c2297c1e..cde27bd108212 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,7 +1,4 @@ -import abc -import operator -from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Optional, Tuple +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple import torch import torch._inductor.pattern_matcher as pm @@ -10,11 +7,14 @@ from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload from vllm.config import CompilationConfig from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .fx_utils import find_getitem_maybe +from .multi_output_match import MultiOutputMatch +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -27,157 +27,66 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, 
dtype=torch.float32, device="cuda") -# Utilities for post-processing multi-output matches - - -# Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: - for node in nodes: - if is_func(node, auto_functionalized) and node.args[0] == op: # noqa - return node - return None - - -# Returns the first auto_functionalized node with the given op -def find_auto_fn(nodes: Iterable[fx.Node], op) -> fx.Node: - node = find_auto_fn_maybe(nodes, op) - assert node is not None, f"Could not find {op} in nodes {nodes}" - return node - - -# Returns the getitem node that extracts the idx-th element from node -# (if it exists) -def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: - for user in node.users: - if is_func(user, operator.getitem) and user.args[1] == idx: - return user - return None - - -# Returns the getitem node that extracts the idx-th element from node -def find_getitem(node: fx.Node, idx: int) -> fx.Node: - ret = find_getitem_maybe(node, idx) - assert ret is not None, f"Could not find getitem {idx} in node {node}" - return ret +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -class MultiOutputMatch(abc.ABC): +class QuantKey(NamedTuple): """ - This class provides utilities to process multi-output matches and - manually insert replacements. - - This is necessary because the automatic replacement for multi-output - matches is broken: https://github.com/pytorch/pytorch/issues/137280 + Named tuple for identifying the type of quantization. + dtype: quantized data type + static: static quantization if True, dynamic if False + per_tensor: per-tensor quantization if True, per-token if False + symmetric: symmetric if True, asymmetric if False """ - - def __init__(self, match: pm.Match): - self.match = match - - @abstractmethod - def process(self): - """ - Process a multi-output match and manually insert the replacement. - - This method should: - 1. Insert the replacement nodes after the last node in the match. - 2. Rebind the users of nodes in the match to use the new nodes. - 3. Set meta["val"] for de-functionalization. - - The result of an auto-functionalized node is a tuple of tensors. - The first element is the return value of the function, usually None. - The remaining elements are the mutated args of the function. - - All auto-functionalized nodes must contain a proper meta["val"], - as it is used by de-functionalization. meta["val"] has to contain the - value of the node (tuple of tensors) that would be returned by the - functionalized node during tracing. - - Existing nodes in the graph all have this property set, but we have - to set it manually for new nodes we insert. - - Example: - # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None - at = auto_functionalized(torch.ops._C.foo.default, a, b, c) - # at.meta["val"] = (None, a, c) - """ - raise NotImplementedError - - @property - def nodes(self) -> List[fx.Node]: - return self.match.nodes - - @property - def graph(self) -> fx.Graph: - return self.match.graph - - def find_auto_fn(self, op) -> fx.Node: - """ - Find the first auto_functionalized node with the given op in the match. - """ - return find_auto_fn(self.nodes, op) - - def inserting_after_match(self): - """ - Insert nodes after the last node in the match. - This is done to avoid use-before-definition errors after inserting - replacement nodes. - """ - - # match.nodes is not guaranteed to be sorted. 
- # Find the last node in the match. - for last_node_in_match in reversed(self.graph.nodes): - if last_node_in_match in self.match.nodes: - break - else: - raise ValueError("No nodes in graph") - - return self.graph.inserting_after(last_node_in_match) - - def insert_getitems(self, tuple_node: fx.Node, - indices: Iterable[int]) -> Tuple[fx.Node, ...]: - """ - Insert operator.getitem nodes to extract elements from a tuple node. - - :param tuple_node: The tuple node to extract elements from. - :param indices: The indices of the elements to extract. - :return: Tuple of the new getitem nodes, corresponding to the indices. - """ - with self.graph.inserting_after(tuple_node): - return tuple( - self.graph.call_function(operator.getitem, (tuple_node, idx)) - for idx in indices) - - def insert_auto_fn(self, op, kwargs): - """ - Insert an auto_functionalized node with the given op and kwargs. - """ - return self.graph.call_function(auto_functionalized, (op, ), - kwargs=kwargs) - - -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - -# Key: (fp8/int8, static/dynamic, per-tensor/per-token, symmetric/asymmetric) -QUANT_OPS = { - (FP8_DTYPE, True, True, True): - torch.ops._C.static_scaled_fp8_quant.default, - (FP8_DTYPE, False, True, True): - torch.ops._C.dynamic_scaled_fp8_quant.default, - (FP8_DTYPE, False, False, True): - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, + dtype: torch.dtype + static: bool + per_tensor: bool = True + symmetric: bool = True + + def __str__(self): + return (f"QuantKey({'static' if self.static else 'dynamic'}," + f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'per_tensor' if self.per_tensor else 'per_token'}," + f"{'a' if not self.symmetric else ''}symmetric)") + + +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) + +QUANT_OPS: Dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa } -# Key: (quant_key, fused_add) -FUSED_OPS = { - ((FP8_DTYPE, True, True, True), False): - torch.ops._C.rms_norm_static_fp8_quant.default, - ((FP8_DTYPE, True, True, True), True): - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ((FP8_DTYPE, False, False, True), False): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, - ((FP8_DTYPE, False, False, True), True): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, + +class FusedRMSQuantKey(NamedTuple): + """ + Named tuple for identifying the type of RMSNorm + quant fusion. 
+ quant: type of quantization + fused_add: does the op also perform the residual add + """ + quant: QuantKey + fused_add: bool + + def __str__(self): + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") + + +FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = { + FusedRMSQuantKey(kFp8StaticTensorSym, False): + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8StaticTensorSym, True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa } @@ -185,8 +94,10 @@ class QuantMultiOutputMatch(MultiOutputMatch): def __init__(self, match: pm.Match, quant_op, fused_op): super().__init__(match) - self.QUANT_OP = quant_op - self.FUSED_OP = fused_op + assert isinstance(quant_op, OpOverload) + assert isinstance(fused_op, OpOverload) + self.QUANT_OP = quant_op # in-place quant op + self.FUSED_OP = fused_op # in-place fused quant op def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, int]], @@ -212,7 +123,7 @@ def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} Note that the 0th element is None for auto-functionalized in-place ops. - Hence others appear 1-indexed. + Hence, others appear 1-indexed. """ fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) indices = fused_return_mapping.keys() @@ -243,32 +154,17 @@ def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, class RMSNormQuantPattern: - def __init__(self, - epsilon: float, - fused_add: bool, - quant_dtype: torch.dtype, - static: bool, - per_tensor: bool = True, - symmetric=True): + def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon - self.quant_dtype = quant_dtype - - # nicer assert - keystr = lambda: ( - f"({'static' if static else 'dynamic'}, {quant_dtype}, " - f"{'per_tensor' if per_tensor else 'per_token'}, " - f"{'a' if not symmetric else ''}symmetric)") + self.quant_dtype = key.quant.dtype - key = (quant_dtype, static, per_tensor, symmetric) - assert key in QUANT_OPS, f"unsupported quantization scheme {keystr()}" - self.QUANT_OP = QUANT_OPS[key] + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" + self.QUANT_OP = QUANT_OPS[key.quant] - key2 = (key, fused_add) - assert key2 in FUSED_OPS, ( - f"unsupported fused rmsnorm+quant op with" - f"{'out' if not fused_add else ''} residual)" - f" for quant scheme {keystr()})") - self.FUSED_OP = FUSED_OPS[key2] + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" + self.FUSED_OP = FUSED_OPS[key] class RMSNormStaticQuantPattern(RMSNormQuantPattern): @@ -277,12 +173,12 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): - super().__init__(epsilon, - fused_add=False, - quant_dtype=quant_dtype, - static=True, - per_tensor=True, - symmetric=symmetric) + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing @@ -333,12 +229,12 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): - 
super().__init__(epsilon, - fused_add=True, - quant_dtype=quant_dtype, - static=True, - per_tensor=True, - symmetric=symmetric) + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): @@ -426,12 +322,12 @@ def __init__(self, quant_dtype: torch.dtype, per_tensor: bool, symmetric=True): - super().__init__(epsilon, - fused_add=False, - quant_dtype=quant_dtype, - static=False, - per_tensor=per_tensor, - symmetric=symmetric) + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): @@ -524,12 +420,12 @@ def __init__(self, quant_dtype: torch.dtype, per_tensor: bool = True, symmetric=True): - super().__init__(epsilon, - fused_add=True, - quant_dtype=quant_dtype, - static=False, - per_tensor=per_tensor, - symmetric=symmetric) + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, record_match: Callable[[MultiOutputMatch], bool]): diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py new file mode 100644 index 0000000000000..924e26f2e262e --- /dev/null +++ b/vllm/compilation/fx_utils.py @@ -0,0 +1,42 @@ +import operator +from typing import Iterable, Optional + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._ops import OpOverload + + +def is_func(node: fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py new file mode 100644 index 0000000000000..0ad648abfbb3a --- /dev/null +++ b/vllm/compilation/multi_output_match.py @@ -0,0 +1,105 @@ +import abc +import operator +from abc import abstractmethod +from typing import Iterable, List, Tuple + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor import pattern_matcher as pm +from torch._ops import OpOverload + +from 
vllm.compilation.fx_utils import find_auto_fn + + +class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. + + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + + def __init__(self, match: pm.Match): + self.match = match + + @abstractmethod + def process(self): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. + + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + @property + def nodes(self) -> List[fx.Node]: + return self.match.nodes + + @property + def graph(self) -> fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(self.graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return self.graph.inserting_after(last_node_in_match) + + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> Tuple[fx.Node, ...]: + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. + """ + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op: OpOverload, kwargs): + """ + Insert an auto_functionalized node with the given op and kwargs. 
+ """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py index 63a369fe8d966..ba28b1f0be7bd 100644 --- a/vllm/compilation/reshapes.py +++ b/vllm/compilation/reshapes.py @@ -5,7 +5,8 @@ from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index dbf6b8f7789e1..b8c52a7f46838 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -16,10 +16,6 @@ logger = init_logger(__name__) -def is_func(node: torch.fx.Node, target) -> bool: - return node.op == "call_function" and node.target == target - - class VllmInductorPass(InductorPass): """ An inductor pass with access to vLLM PassConfig. From 0a9a96faa56110f9b3cd2e202a6904196ff29c51 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Dec 2024 22:02:23 +0000 Subject: [PATCH 18/21] PR comments: kernels Signed-off-by: luka --- csrc/quantization/fp8/common.cuh | 24 ++++--------------- ...fused_layernorm_dynamic_per_token_quant.cu | 1 + .../fused_kernels/layernorm_utils.cuh | 4 +++- .../fused_kernels/quant_conversions.cuh | 18 +++++++------- .../{fused_kernels => }/vectorization.cuh | 10 ++++++-- 5 files changed, 27 insertions(+), 30 deletions(-) rename csrc/quantization/{fused_kernels => }/vectorization.cuh (59%) diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index d7c0297d5333f..ca192b1db6528 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -1,5 +1,7 @@ #pragma once +#include "quantization/vectorization.cuh" + #include #ifndef USE_ROCM @@ -89,22 +91,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale, } } -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -typedef struct __align__(4) { - FP8_TYPE x; - FP8_TYPE y; - FP8_TYPE z; - FP8_TYPE w; -} -float8x4_t; - template __device__ float thread_max_vec(scalar_t const* __restrict__ input, int64_t const num_elems, int const tid, @@ -139,10 +125,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, float const scale, int64_t const num_elems, int const tid, int const step) { + using float8x4_t = q8x4_t; // Vectorized input/output to better utilize memory bandwidth. 
- vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - float8x4_t* vectorized_out = reinterpret_cast(out); + auto const* vectorized_in = reinterpret_cast const*>(input); + auto* vectorized_out = reinterpret_cast(out); int64_t const num_vec_elems = num_elems >> 2; diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 7bb406458983b..7e2c8f9f83a7e 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -150,6 +150,7 @@ void rms_norm_dynamic_per_token_quant( if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); } + TORCH_CHECK(scales.dtype() == torch::kFloat32); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 05ed221930dc8..f729fe58c0c08 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -4,7 +4,7 @@ * __device__ layernorm utilities. */ -#include "vectorization.cuh" +#include "quantization/vectorization.cuh" #include "quant_conversions.cuh" #ifndef USE_ROCM @@ -279,6 +279,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, int32_t const num_vec_elems = hidden_size >> 2; +// TODO(luka/varun) extract into type-agnostic vectorized quant function to +// replace scaled_fp8_conversion_vec #pragma unroll 4 for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { vec4_t const in = vec_input[i]; diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index abf32f40c4b0e..f8a9872226a3a 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -4,10 +4,14 @@ * __device__ helper functions to deal with float -> quant datatype conversion */ -#include "vectorization.cuh" +#include "quantization/vectorization.cuh" +// TODO(luka/varun):refactor common.cuh to use this file instead +#include "quantization/fp8/common.cuh" namespace vllm { +// TODO(luka/varun): combine into common utilities for int8 +// (with int8_quant_kernels.cu) static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { #ifdef USE_ROCM static const float i8_min = @@ -27,11 +31,9 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { #endif } -#define FP8_E4M3_MAX std::numeric_limits::max() -static __device__ __forceinline__ c10::Float8_e4m3fn float_to_fp8( - float const x) { +static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) { float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); - return static_cast(r); + return static_cast(r); } template @@ -52,9 +54,9 @@ struct ScaledQuant< }; template -struct ScaledQuant>> { +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { static __device__ __forceinline__ quant_type_t quant_fn(float const x, float const scale) { if constexpr (is_scale_inverted) { diff --git a/csrc/quantization/fused_kernels/vectorization.cuh b/csrc/quantization/vectorization.cuh similarity index 59% rename from csrc/quantization/fused_kernels/vectorization.cuh rename to csrc/quantization/vectorization.cuh index 7ba0df6b11ce4..44c999130f756 100644 --- 
a/csrc/quantization/fused_kernels/vectorization.cuh +++ b/csrc/quantization/vectorization.cuh @@ -1,8 +1,13 @@ #pragma once /** - * __device__ algorithms that perform vectorized loads/stores of input/output. + * __device__ datatypes vectorized by 4 */ +// Include both AMD and NVIDIA fp8 types to avoid circular import +// TODO(luka/varun) use FP8_TYPE instead after refactoring +#include +#include + namespace vllm { // Vectorization containers @@ -17,7 +22,8 @@ struct __align__(8) vec4_t { template struct __align__(4) q8x4_t { static_assert(std::is_same_v || - std::is_same_v); + std::is_same_v || + std::is_same_v); quant_type_t x; quant_type_t y; quant_type_t z; From 720d537d70a1b4bf9d7202b67046dbe1546e7cbc Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 10 Dec 2024 22:02:33 +0000 Subject: [PATCH 19/21] PR comments: opcheck Signed-off-by: luka --- tests/kernels/test_fused_quant_layernorm.py | 47 +++++++++++++-------- vllm/_custom_ops.py | 14 ------ 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index 3997f4e9b8fe9..ff8e807ecb600 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -4,6 +4,7 @@ import torch import vllm._custom_ops as ops +from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm DTYPES = [torch.bfloat16, torch.float] @@ -27,11 +28,11 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: return torch.as_tensor(x, dtype=torch.float32, device='cuda') + def ref_rms_norm(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - + -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() out, residual = rms_norm_layer.forward_native(x, residual) @@ -40,13 +41,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm, return out, residual -def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -64,22 +65,23 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, return torch_out, scales, residual + def ref_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, quant_dtype: torch.dtype, residual: Optional[torch.Tensor], scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, residual, scale_ub) -def ops_dynamic_per_token_quant(weight: torch.Tensor, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ops_dynamic_per_token_quant(weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = 
residual.clone() out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, @@ -87,12 +89,13 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor, residual) return out, scales, residual + def ops_impl(weight: torch.Tensor, x: torch.Tensor, quant_dtype: torch.dtype, residual: Optional[torch.Tensor], scale_ub: Optional[torch.Tensor]) \ - -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) @@ -139,9 +142,9 @@ def test_rms_norm( scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda') ref_out, ref_scales, ref_residual = \ - ref_impl(layer, x, quant_dtype, residual, scale_ub) + ref_impl(layer, x, quant_dtype, residual, scale_ub) ops_out, ops_scales, ops_residual = \ - ops_impl(layer.weight, x, quant_dtype, residual, scale_ub) + ops_impl(layer.weight, x, quant_dtype, residual, scale_ub) assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype @@ -154,3 +157,11 @@ def test_rms_norm( ops_out.to(dtype=torch.float32)) if add_residual: assert torch.allclose(ref_residual, ops_residual) + + output = torch.empty_like(x, dtype=quant_dtype) + scales = torch.empty((x.numel() // x.shape[-1], 1), + device=x.device, + dtype=torch.float32) + + opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant, + (output, x, layer.weight, scales, 1e-5, scale_ub, residual)) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1964b934e1986..d6002630ee02c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -269,20 +269,6 @@ def rms_norm_dynamic_per_token_quant( return output, scales -if hasattr(torch.ops._C, "rms_norm_dynamic_per_token_quant"): - - @register_fake("_C::rms_norm_dynamic_per_token_quant") - def _rms_norm_dynamic_per_token_quant_fake( - output: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scales: torch.Tensor, - epsilon: float, - scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None) -> None: - return None - - # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, From a70d496a8256d7a4b4a2adb6cce2ca210a0f8af9 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 12 Dec 2024 21:37:25 +0000 Subject: [PATCH 20/21] Fix dispatch utils for AMD Signed-off-by: luka --- csrc/dispatch_utils.h | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index aa5c8dbbae182..03414b7e1ae93 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,9 +14,16 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +// TODO(luka/varun): use FP8_TYPE macro after refactoring +#ifndef USE_ROCM + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#endif #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) 
\ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) From 8be34ea6e02d4b270e597998caf9b6e0368237ba Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 12 Dec 2024 23:30:51 +0000 Subject: [PATCH 21/21] PR comments: - add kFp8Type constant for cuda/hip agnostic torch type checking - check contiguous - overflow - reduce number of tests Signed-off-by: luka --- csrc/quantization/fp8/common.cuh | 2 ++ .../fused_layernorm_dynamic_per_token_quant.cu | 6 +++--- .../fused_kernels/layernorm_utils.cuh | 16 ++++++++++------ tests/kernels/test_fused_quant_layernorm.py | 16 ++++++++++------ 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index ca192b1db6528..15bd5b6ed1564 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -3,6 +3,7 @@ #include "quantization/vectorization.cuh" #include +#include #ifndef USE_ROCM #include @@ -17,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz; // issue when running dynamic quantization. Here use 224.0f for rocm. constexpr auto FP8_E4M3_MAX = 224.0f; #endif +constexpr static auto kFp8Type = c10::CppTypeToScalarType::value; namespace vllm { diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 7e2c8f9f83a7e..3c4f183bf4b59 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -144,11 +144,11 @@ void rms_norm_dynamic_per_token_quant( torch::Tensor& scales, // [num_tokens] double const var_epsilon, // Variance epsilon used in norm calculation std::optional scale_ub, std::optional residual) { - TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn || - out.dtype() == torch::kInt8); + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); if (scale_ub.has_value()) { - TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(out.dtype() == kFp8Type); } TORCH_CHECK(scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index f729fe58c0c08..cec6b54edb569 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -20,7 +20,7 @@ template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // sum of squares float ss = 0.0f; @@ -53,7 +53,8 @@ __device__ void compute_dynamic_per_token_scales( float const rms, float const* __restrict__ scale_ub, float const min_scaling_factor, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; constexpr scalar_out_t qmax{std::numeric_limits::max()}; float block_absmax_val_maybe = 0.0f; @@ -99,7 +100,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, float const rms, float const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t 
const token_offset = blockIdx.x * static_cast(hidden_size); + ; for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -123,7 +125,7 @@ template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vec_input = @@ -184,7 +186,8 @@ __device__ void compute_dynamic_per_token_scales( float const rms, float const* __restrict__ scale_ub, float const min_scaling_factor, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; // Vectorized input/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = @@ -263,7 +266,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, float const rms, float const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index ff8e807ecb600..baf8d73fdbffb 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -9,10 +9,15 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] -NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [1, 3, 4, 16, 64, 2048, 5120, - 5137] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +VEC_HIDDEN_SIZES = range(1024, 1030) +# Avoid combinatorial explosion with full Cartesian product +NUM_TOKENS_HIDDEN_SIZES = [ + *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]], + *[(83, i) for i in [1, 1033, 2048, 5120]], + *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]], + *[(4096, i) for i in [1, 64, 5137]], +] + ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] SEEDS = [0] @@ -100,8 +105,7 @@ def ops_impl(weight: torch.Tensor, scale_ub) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("dtype", DTYPES)
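# Aside (for reference, derived from the lists above): NUM_TOKENS_HIDDEN_SIZES
# expands to 26 (num_tokens, hidden_size) pairs (10 + 4 + 9 + 3), compared to
# the 5 x 17 = 85 combinations produced by the previous full Cartesian product
# of NUM_TOKENS and HIDDEN_SIZES, while still covering the vectorized
# conversion edge cases around hidden sizes 1024 to 1029.
#
#     >>> len(NUM_TOKENS_HIDDEN_SIZES)
#     26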