diff --git a/backends/cadence/hifi/operators/op_bmm.cpp b/backends/cadence/hifi/operators/op_bmm.cpp
index 31ae966966..0262703bb7 100644
--- a/backends/cadence/hifi/operators/op_bmm.cpp
+++ b/backends/cadence/hifi/operators/op_bmm.cpp
@@ -16,8 +16,8 @@ using exec_aten::ScalarType;
 using executorch::runtime::KernelRuntimeContext;
 using executorch::runtime::kTensorDimensionLimit;
 using executorch::runtime::resize_tensor;
-using executorch::runtime::tensors_have_same_dim_order;
 using executorch::runtime::tensor_is_default_dim_order;
+using executorch::runtime::tensors_have_same_dim_order;
 using torch::executor::check_bmm_args;
 using torch::executor::Error;
 using torch::executor::get_bmm_out_target_size;
@@ -78,16 +78,16 @@ Tensor& bmm_out(
   WORD32 out_stride = p;

   WORD32* __restrict__ tmp =
-      (WORD32* __restrict__)kernels::allocate_temp_memory(
-          ctx, (batch_size * m * p) * sizeof(float));
+      (WORD32* __restrict__)kernels::allocate_temp_memory(
+          ctx, (batch_size * m * p) * sizeof(float));

   ET_KERNEL_CHECK(ctx, tmp != nullptr, MemoryAllocationFailed, out);

   tmp[batch_size * m * p] = {0};

   WORD32* __restrict__ p_o =
-      (WORD32* __restrict__)kernels::allocate_temp_memory(
-          ctx, (batch_size * m * p) * sizeof(WORD32));
+      (WORD32* __restrict__)kernels::allocate_temp_memory(
+          ctx, (batch_size * m * p) * sizeof(WORD32));

   ET_KERNEL_CHECK(ctx, p_o != nullptr, MemoryAllocationFailed, out);

diff --git a/backends/cadence/hifi/operators/op_mm.cpp b/backends/cadence/hifi/operators/op_mm.cpp
index ceedc97eeb..334eba0a15 100644
--- a/backends/cadence/hifi/operators/op_mm.cpp
+++ b/backends/cadence/hifi/operators/op_mm.cpp
@@ -76,8 +76,8 @@ Tensor& mm_out(
   WORD32 out_stride = p;

   WORD32* __restrict__ p_o =
-      (WORD32* __restrict__)kernels::allocate_temp_memory(
-          ctx, (n * p) * sizeof(WORD32));
+      (WORD32* __restrict__)kernels::allocate_temp_memory(
+          ctx, (n * p) * sizeof(WORD32));

 WORD32 p_inp_shape[2];
 p_inp_shape[0] = n;
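The two HiFi hunks above are formatting-only: they re-indent the `allocate_temp_memory` calls and alphabetize the `using` declarations; the buffer sizing itself is unchanged. As a quick sanity check on that sizing, here is a minimal PyTorch sketch — the shapes and names are illustrative stand-ins for the kernel's locals, not taken from the kernel:

```python
import torch

# bmm multiplies (batch_size, m, k) by (batch_size, k, p) -> (batch_size, m, p),
# so a scratch buffer with one slot per output element needs batch_size*m*p
# entries (and n*p entries for the single-matrix mm case).
batch_size, m, k, p = 4, 8, 16, 32
x = torch.randn(batch_size, m, k)
y = torch.randn(batch_size, k, p)
out = torch.bmm(x, y)
assert out.shape == (batch_size, m, p)
assert out.numel() == batch_size * m * p
```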
diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py
index f13b035552..c2818022e7 100644
--- a/backends/qualcomm/_passes/i64_to_i32.py
+++ b/backends/qualcomm/_passes/i64_to_i32.py
@@ -31,6 +31,14 @@ class I64toI32(ExportPass):
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
     }
+    # This dict ensures that the listed inputs of these ops stay int64, as PyTorch requires.
+    # For example, the scatter op only accepts args[2], the index, as int64.
+    # Key: op whose input must be cast to i64
+    # Value: indices of the args that need a cast op
+    I64_IN_OPS = {
+        exir_ops.edge.aten.gather.default: [2],
+        exir_ops.edge.aten.scatter.src: [2],
+    }
     copy_op = exir_ops.edge.aten._to_copy.default

     def __init__(
@@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
                 n.replace_all_uses_with(to_dst_node)
                 to_dst_node.args = (n,)

+    def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
+        # inputs will be cast to i32 during call_operator dtype propagation;
+        # insert an i64 cast node to prevent operator validation failures
+        for node in graph_module.graph.nodes:
+            if node.target in self.I64_IN_OPS:
+                with graph_module.graph.inserting_before(node):
+                    arg_indices = self.I64_IN_OPS[node.target]
+                    for arg_index in arg_indices:
+                        input_node = node.args[arg_index]
+                        cast_i64_node = graph_module.graph.create_node(
+                            "call_function",
+                            self.copy_op,
+                            (input_node,),
+                            {"dtype": torch.int64},
+                        )
+                        cast_i64_node.meta["val"] = node.meta["val"].to(torch.int64)
+                        args_list = list(node.args)
+                        args_list[arg_index] = cast_i64_node
+                        node.args = tuple(args_list)
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # Record original output dtype to ensure that if user expects int64 as output,
         # convert the output back to int64 if it is casted from int64->int32.
         self._record_original_output_dtype(graph_module)
         self._cast_constant_to_int32(graph_module)
+        self._cast_op_args_to_i64(graph_module)
         graph_module = super().call(graph_module).graph_module
         self._preserve_output_dtype(graph_module)
         graph_module.recompile()
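The restriction that `I64_IN_OPS` works around is visible in eager mode: once the pass demotes tensors to int32, `gather`/`scatter` would reject their index argument, so the pass re-inserts an int64 cast in front of it. A minimal demonstration using only stock PyTorch:

```python
import torch

x = torch.arange(6, dtype=torch.float32).view(2, 3)
idx = torch.tensor([[0, 2], [1, 0]])  # int64 by default

print(torch.gather(x, 1, idx))  # fine: index is int64

try:
    torch.gather(x, 1, idx.to(torch.int32))  # int32 index is rejected
except RuntimeError as e:
    print(e)  # PyTorch requires an int64 index for gather/scatter
```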
""" - def __init__(self): + def __init__(self, quantization_capture=False): super(RemoveRedundancy, self).__init__() - self.redundant_ops = { + self.redundant_ops_general = { torch.clone: self._default_condition, torch.ops.aten.clone.default: self._default_condition, exir_ops.edge.aten.clone.default: self._default_condition, @@ -27,7 +27,16 @@ def __init__(self): exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition, # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition, + torch.ops.aten._assert_tensor_metadata.default: self._default_condition, } + self.redundant_ops_annotation = { + torch.ops.aten._assert_tensor_metadata.default: self._default_condition, + } + self.redundant_ops = ( + self.redundant_ops_annotation + if quantization_capture + else self.redundant_ops_general + ) def _dim_order_op_condition(self, node): dim_order = node.kwargs.get("dim_order") @@ -49,6 +58,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: continue to_be_remove = n + # assert_tensor_metadata op has no user + if len(n.users.keys()) == 0: + n.args = () + # normal case for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) graph_module.graph.erase_node(to_be_remove) diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 705d5d163c..27faa036dd 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -32,6 +32,7 @@ op_expand, op_full, op_full_like, + op_gather, op_ge, op_gelu, op_group_norm, @@ -120,6 +121,7 @@ op_expand, op_full, op_full_like, + op_gather, op_ge, op_gelu, op_group_norm, diff --git a/backends/qualcomm/builders/op_gather.py b/backends/qualcomm/builders/op_gather.py new file mode 100644 index 0000000000..877928f17b --- /dev/null +++ b/backends/qualcomm/builders/op_gather.py @@ -0,0 +1,101 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index 705d5d163c..27faa036dd 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -32,6 +32,7 @@
     op_expand,
     op_full,
     op_full_like,
+    op_gather,
     op_ge,
     op_gelu,
     op_group_norm,
@@ -120,6 +121,7 @@
     op_expand,
     op_full,
     op_full_like,
+    op_gather,
     op_ge,
     op_gelu,
     op_group_norm,

diff --git a/backends/qualcomm/builders/op_gather.py b/backends/qualcomm/builders/op_gather.py
new file mode 100644
index 0000000000..877928f17b
--- /dev/null
+++ b/backends/qualcomm/builders/op_gather.py
@@ -0,0 +1,101 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpCast, OpGatherElements, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Gather(NodeVisitor):
+    target = ["aten.gather.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        dim = cast(int, node.args[1])
+
+        indices_node = node.args[2]
+        indices_tensor = self.get_tensor(indices_node, node)
+        indices_tensor_wrapper = self.define_tensor(
+            indices_node,
+            node,
+            indices_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        cast_node = self.edge_program.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten._to_copy.default,
+            (indices_node,),
+            {"dtype": torch.int32},
+        )
+        cast_node.meta["val"] = indices_node.meta["val"].to(torch.int32)
+        cast_tensor = self.get_tensor(cast_node, node)
+        cast_tensor_wrapper = self.define_tensor(
+            cast_node,
+            node,
+            cast_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        # the graph must not be modified during the partition stage;
+        # erase the node here to prevent lowering failures
+        self.edge_program.graph.erase_node(cast_node)
+        cast_op = PyQnnWrapper.PyQnnOpWrapper(
+            f"{node.name}_cast_i64_to_i32", QNN_OP_PACKAGE_NAME_QTI_AISW, OpCast.op_name
+        )
+        cast_op.AddInputTensors([indices_tensor_wrapper])
+        cast_op.AddOutputTensors([cast_tensor_wrapper])
+
+        gather_input_tensors = [input_tensor_wrapper, cast_tensor_wrapper]
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        gather_output_tensors = [output_tensor_wrapper]
+
+        gather_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpGatherElements.op_name,
+        )
+        gather_op.AddInputTensors(gather_input_tensors)
+        gather_op.AddOutputTensors(gather_output_tensors)
+        gather_op.AddScalarParam(
+            OpGatherElements.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        return [cast_op, gather_op]
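Note that the visitor lowers `aten.gather` to QNN's `GatherElements` rather than `Gather`: the index tensor has the same rank as the input and selects elements one-by-one along `axis`, which matches `torch.gather` semantics. The extra `Cast` op exists because the indices arrive as int64 (kept that way by the `I64toI32` pass above), while the QNN op consumes int32. A small eager-mode illustration of the contract:

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
index = torch.tensor([[0, 0], [1, 0]])

# For dim=0: out[i][j] = x[index[i][j]][j] — per-element selection along the
# axis, which is what GatherElements computes on the QNN side.
print(torch.gather(x, 0, index))  # tensor([[1., 2.], [3., 2.]])
```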
diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py
index 8d12e03c0b..c6d52a7430 100644
--- a/backends/qualcomm/builders/op_slice_copy.py
+++ b/backends/qualcomm/builders/op_slice_copy.py
@@ -50,12 +50,17 @@ def define_node(
         dim = cast(int, node.args[1])
         if dim < 0:
             dim = dim % len(input_tensor.shape)
-        start = cast(int, node.args[2])
+
+        start = 0 if node.args[2] is None else cast(int, node.args[2])
         if start < 0:
             start = start % input_tensor.shape[dim]
-        end = min(cast(int, node.args[3]), input_tensor.shape[dim])
-        if end < 0:
-            end = end % input_tensor.shape[dim]
+
+        if len(node.args) > 3:
+            end = min(cast(int, node.args[3]), input_tensor.shape[dim])
+            if end < 0:
+                end = end % input_tensor.shape[dim]
+        else:
+            end = input_tensor.shape[dim]

         input_tensor_rank = len(input_tensor.shape)
         ranges = []
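The new branches handle `slice_copy` nodes whose `start` is `None` or whose `end` argument is omitted — which is what indexing like `x[:1]` and `x[1:]` produces after export. A hypothetical standalone helper mirroring the builder's normalization, just to make the defaulting explicit:

```python
def normalize_slice(shape, dim, start=None, end=None):
    # mirrors op_slice_copy.py: default start to 0, default end to the full
    # extent, clamp end to the dim size, and wrap negative values into range
    if dim < 0:
        dim = dim % len(shape)
    start = 0 if start is None else start
    if start < 0:
        start = start % shape[dim]
    if end is None:
        end = shape[dim]
    else:
        end = min(end, shape[dim])
        if end < 0:
            end = end % shape[dim]
    return dim, start, end

# x[1:] on a (2, 1, 320, 512) tensor: no end argument -> full extent
assert normalize_slice((2, 1, 320, 512), 0, start=1) == (0, 1, 2)
# x[:1]: no start argument -> begin at 0
assert normalize_slice((2, 1, 320, 512), 0, end=1) == (0, 0, 1)
```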
diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py
index 5fb016aef9..9e37823fdf 100644
--- a/backends/qualcomm/builders/op_to.py
+++ b/backends/qualcomm/builders/op_to.py
@@ -9,6 +9,7 @@
 import torch

 from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+from executorch.exir.dialects._ops import ops as exir_ops

 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -90,9 +91,48 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        node_input_tensors = [input_tensor_wrapper]
+
+        # if the output / input dtype is int64, cast it to int32 first,
+        # since int32 is the only source dtype that can be cast into int64
+        ops = []
+        if (
+            (
+                node.meta["val"].dtype == torch.int64
+                or input_node.meta["val"].dtype == torch.int64
+            )
+            # no need to add another cast node if the dtype is already an integer type
+            and input_node.meta["val"].dtype not in (torch.int32, torch.int64)
+        ):
+            cast_node = self.edge_program.graph.create_node(
+                "call_function",
+                exir_ops.edge.aten._to_copy.default,
+                (input_node,),
+                {"dtype": torch.int32},
+            )
+            cast_node.meta["val"] = input_node.meta["val"].to(torch.int32)
+            cast_tensor = self.get_tensor(cast_node, node)
+            cast_tensor_wrapper = self.define_tensor(
+                cast_node,
+                node,
+                cast_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                nodes_to_wrappers,
+            )
+            # the graph must not be modified during the partition stage;
+            # erase the node here to prevent lowering failures
+            self.edge_program.graph.erase_node(cast_node)
+            cast_op = PyQnnWrapper.PyQnnOpWrapper(
+                f"{node.name}_cast_i64_to_i32",
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpCast.op_name,
+            )
+            node_input_tensors = [cast_tensor_wrapper]
+            cast_op.AddInputTensors([input_tensor_wrapper])
+            cast_op.AddOutputTensors([cast_tensor_wrapper])
+            ops.append(cast_op)

         output_tensor = self.get_tensor(node, node)
-
         output_tensor_wrapper = self.define_tensor(
             node,
             node,
@@ -105,7 +145,8 @@ def define_node(
         op = PyQnnWrapper.PyQnnOpWrapper(
             node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
         )
-        op.AddInputTensors([input_tensor_wrapper])
+        op.AddInputTensors(node_input_tensors)
         op.AddOutputTensors([output_tensor_wrapper])
+        ops.append(op)

-        return op
+        return ops

diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 06e398f7c0..c13a126f76 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -252,6 +252,12 @@ class OpGather:
     param_axis: str = "axis"


+@dataclass(init=False, frozen=True)
+class OpGatherElements:
+    op_name: str = "GatherElements"
+    param_axis: str = "axis"
+
+
 @dataclass(init=False, frozen=True)
 class OpGatherND:
     op_name: str = "GatherNd"

diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 469a801fee..c2b15c5f22 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -750,7 +750,7 @@ def annotate_elu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)


-@register_annotator([torch.ops.aten.embedding.default])
+@register_annotator([torch.ops.aten.embedding.default, torch.ops.aten.gather.default])
 def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
     weight = node.args[0]

diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index adf6e256f5..69934414a5 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -729,6 +729,34 @@ def forward(self, x):
         return torch.min(x, torch.full_like(x, self.fill))


+class Gather(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.gather(x, dim=1, index=y)
+
+
+class GatherArgmin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        index = torch.argmin(x, dim=1, keepdim=True)
+        return torch.gather(x, dim=1, index=index)
+
+
+class GatherWhere(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        index = torch.where(y > 0, torch.Tensor([1]).int(), torch.Tensor([1]).int()).to(
+            torch.int64
+        )
+        return torch.gather(x, x.dim() - 1, index)
+
+
 class Gelu(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1398,6 +1426,14 @@ def forward(self, x, y):
         return x[:, :seq_length] + self.position_ids[:, :seq_length]


+class SliceCopyDefaultParameter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.cat([x[:1], x[1:]], dim=1)
+
+
 class SliceCopyWithStep(torch.nn.Module):
     def __init__(self):
         super().__init__()
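The three new gather models cover the index sources the builder has to handle: a graph input, an `argmin` result, and a `where` result cast to int64. A quick eager-mode sanity check with the same sample shapes the tests use, relying only on stock PyTorch:

```python
import torch

x = torch.arange(128, dtype=torch.float32).view(64, 2)

# Gather: index arrives as a graph input
out = torch.gather(x, dim=1, index=torch.ones(64, 2, dtype=torch.int64))
assert out.shape == (64, 2)

# GatherArgmin: gathering at argmin reproduces the row-wise minimum
index = torch.argmin(x, dim=1, keepdim=True)
gathered = torch.gather(x, dim=1, index=index)
assert torch.equal(gathered.squeeze(1), x.min(dim=1).values)
```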
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 7d097fd45b..3dc5c1fa54 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -478,6 +478,21 @@ def test_qnn_backend_full_like(self):
         sample_input = (torch.randn(1, 2, 3, 4),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_gather(self):
+        modules = [Gather(), GatherArgmin(), GatherWhere()]  # noqa: F405
+        shape = (2, 2, 3, 4)
+        sample_inputs = [
+            (
+                torch.arange(128, dtype=torch.float32).view(64, 2),
+                torch.ones(64, 2, dtype=torch.int64),
+            ),
+            (torch.arange(128, dtype=torch.float32).view(64, 2),),
+            (torch.randn(shape), torch.randn(shape)),
+        ]
+        for i, (module, sample_input) in enumerate(zip(modules, sample_inputs)):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -821,12 +836,17 @@ def test_qnn_backend_select_copy(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_slice_copy(self):
-        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
-        sample_input = (
-            torch.randn([1, 512]),
-            torch.randn([1, 8]),
-        )
-        for module in modules:
+        modules = [
+            SliceCopyDefaultParameter(),  # noqa: F405
+            SliceCopy(),  # noqa: F405
+            SliceCopyWithStep(),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn([2, 1, 320, 512]),),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+        ]
+        for module, sample_input in zip(modules, sample_inputs):
             self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_stack(self):
@@ -1593,6 +1613,22 @@ def test_qnn_backend_full_like(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_gather(self):
+        modules = [Gather(), GatherArgmin(), GatherWhere()]  # noqa: F405
+        shape = (2, 2, 3, 4)
+        sample_inputs = [
+            (
+                torch.arange(128, dtype=torch.float32).view(64, 2),
+                torch.ones(64, 2, dtype=torch.int64),
+            ),
+            (torch.arange(128, dtype=torch.float32).view(64, 2),),
+            (torch.randn(shape), torch.randn(shape)),
+        ]
+        for i, (module, sample_input) in enumerate(zip(modules, sample_inputs)):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -1991,12 +2027,17 @@ def test_qnn_backend_sin(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_slice_copy(self):
-        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
-        sample_input = (
-            torch.randn([1, 512]),
-            torch.randn([1, 8]),
-        )
-        for module in modules:
+        modules = [
+            SliceCopyDefaultParameter(),  # noqa: F405
+            SliceCopy(),  # noqa: F405
+            SliceCopyWithStep(),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn([2, 1, 320, 512]),),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+        ]
+        for module, sample_input in zip(modules, sample_inputs):
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)

diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers
index 3f9c458586..35d185e0f5 160000
--- a/extension/llm/tokenizers
+++ b/extension/llm/tokenizers
@@ -1 +1 @@
-Subproject commit 3f9c458586ee576a7ddafb48eb491f117187e178
+Subproject commit 35d185e0f5e80c261c4ebf4f4993ff55f2792626