Commit 2327a56

Qualcomm AI Engine Direct - xr model enablement (mld_f)

Summary:
- add gather op support
- make the cast / slice ops more general

1 parent df75088

File tree

10 files changed: +230 -22 lines changed


backends/qualcomm/_passes/qnn_pass_manager.py

+1
@@ -182,6 +182,7 @@ def transform_for_to_edge_pipeline(
 
     # Before quantizer
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(RemoveRedundancy(quantization_capture=True))
         self.add_pass(ReduceDynamicRange())
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
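
Note: running RemoveRedundancy ahead of the annotation pipeline (with quantization_capture=True) strips the aten._assert_tensor_metadata nodes that torch.export can leave behind for .to()-style calls, so the quantizer only sees real compute ops. A minimal sketch of the situation being handled (ToyModel is a hypothetical stand-in; whether the assert node appears depends on the PyTorch version):

    import torch

    class ToyModel(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            return x.to(torch.int32) + 1

    ep = torch.export.export(ToyModel(), (torch.randint(0, 10, (2, 3)),))
    asserts = [
        n for n in ep.graph.nodes
        if n.target == torch.ops.aten._assert_tensor_metadata.default
    ]
    print(f"assert_tensor_metadata nodes in the captured graph: {len(asserts)}")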

backends/qualcomm/_passes/remove_redundancy.py

+15-2
@@ -14,9 +14,9 @@ class RemoveRedundancy(ExportPass):
     Trim certain operators to reduce unnecessary overhead.
     """
 
-    def __init__(self):
+    def __init__(self, quantization_capture=False):
         super(RemoveRedundancy, self).__init__()
-        self.redundant_ops = {
+        self.redundant_ops_general = {
             torch.clone: self._default_condition,
             torch.ops.aten.clone.default: self._default_condition,
             exir_ops.edge.aten.clone.default: self._default_condition,
@@ -27,7 +27,16 @@ def __init__(self):
             exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
             # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
             exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
+            torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
         }
+        self.redundant_ops_annotation = {
+            torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
+        }
+        self.redundant_ops = (
+            self.redundant_ops_annotation
+            if quantization_capture
+            else self.redundant_ops_general
+        )
 
     def _dim_order_op_condition(self, node):
         dim_order = node.kwargs.get("dim_order")
@@ -49,6 +58,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 continue
 
             to_be_remove = n
+            # assert_tensor_metadata op has no user
+            if len(n.users.keys()) == 0:
+                n.args = ()
+            # normal case
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
             graph_module.graph.erase_node(to_be_remove)
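
Note: the zero-user branch above exists because _assert_tensor_metadata is a pure side-effect node: nothing consumes its output, so there is no user to rewire, and the pass clears its args before erase_node runs. A standalone torch.fx sketch of the same trick (graph built by hand for illustration; assumes a PyTorch build that exposes the assert op):

    import operator

    import torch
    import torch.fx

    g = torch.fx.Graph()
    x = g.placeholder("x")
    # a dead-end node with no users, mirroring _assert_tensor_metadata
    side = g.call_function(torch.ops.aten._assert_tensor_metadata.default, (x,))
    out = g.call_function(operator.add, (x, 1))
    g.output(out)

    side.args = ()      # detaches the node from x's user list, as the pass does
    g.erase_node(side)  # succeeds: the node has no users
    g.lint()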

backends/qualcomm/builders/__init__.py

+2
@@ -32,6 +32,7 @@
     op_expand,
     op_full,
     op_full_like,
+    op_gather,
     op_ge,
     op_gelu,
     op_group_norm,
@@ -120,6 +121,7 @@
     op_expand,
     op_full,
     op_full_like,
+    op_gather,
     op_ge,
     op_gelu,
     op_group_norm,
backends/qualcomm/builders/op_gather.py

+101

@@ -0,0 +1,101 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpCast, OpGatherElements, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Gather(NodeVisitor):
+    target = ["aten.gather.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        dim = cast(int, node.args[1])
+
+        indices_node = node.args[2]
+        indices_tensor = self.get_tensor(indices_node, node)
+        indices_tensor_wrapper = self.define_tensor(
+            indices_node,
+            node,
+            indices_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        cast_node = self.edge_program.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten._to_copy.default,
+            (indices_node,),
+            {"dtype": torch.int32},
+        )
+        cast_node.meta["val"] = indices_node.meta["val"].to(torch.int32)
+        cast_tensor = self.get_tensor(cast_node, node)
+        cast_tensor_wrapper = self.define_tensor(
+            cast_node,
+            node,
+            cast_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        # graph is not allowed to be modified in partition stage
+        # erase it here to prevent lowering failure
+        self.edge_program.graph.erase_node(cast_node)
+        cast_op = PyQnnWrapper.PyQnnOpWrapper(
+            f"{node.name}_cast_i64_to_i32", QNN_OP_PACKAGE_NAME_QTI_AISW, OpCast.op_name
+        )
+        cast_op.AddInputTensors([indices_tensor_wrapper])
+        cast_op.AddOutputTensors([cast_tensor_wrapper])
+
+        gather_input_tensors = [input_tensor_wrapper, cast_tensor_wrapper]
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        gather_output_tensors = [output_tensor_wrapper]
+
+        gather_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpGatherElements.op_name,
+        )
+        gather_op.AddInputTensors(gather_input_tensors)
+        gather_op.AddOutputTensors(gather_output_tensors)
+        gather_op.AddScalarParam(
+            OpGatherElements.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        return [cast_op, gather_op]
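
Note: aten.gather.default is mapped to QNN GatherElements (per-element gather along one axis) rather than Gather, which selects whole slices like torch.index_select. PyTorch produces int64 index tensors, and the Cast emitted above suggests the backend consumes int32 indices, which is lossless as long as every index fits in 32 bits. A quick eager-mode check of the intended semantics:

    import torch

    x = torch.arange(12.0).reshape(3, 4)
    idx = torch.tensor([[0, 3], [1, 2], [2, 0]])  # int64 by default
    print(torch.gather(x, 1, idx))
    # tensor([[ 0.,  3.],
    #         [ 5.,  6.],
    #         [10.,  8.]])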

backends/qualcomm/builders/op_slice_copy.py

+9-4
@@ -50,12 +50,17 @@ def define_node(
         dim = cast(int, node.args[1])
         if dim < 0:
             dim = dim % len(input_tensor.shape)
-        start = cast(int, node.args[2])
+
+        start = 0 if node.args[2] is None else cast(int, node.args[2])
         if start < 0:
             start = start % input_tensor.shape[dim]
-        end = min(cast(int, node.args[3]), input_tensor.shape[dim])
-        if end < 0:
-            end = end % input_tensor.shape[dim]
+
+        if len(node.args) > 3:
+            end = min(cast(int, node.args[3]), input_tensor.shape[dim])
+            if end < 0:
+                end = end % input_tensor.shape[dim]
+        else:
+            end = input_tensor.shape[dim]
 
         input_tensor_rank = len(input_tensor.shape)
         ranges = []
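
Note: aten.slice_copy can arrive with start=None (e.g. x[:1]) or with the end argument omitted entirely (e.g. x[1:]), which the previous code did not handle. A standalone restatement of the new normalization, handy for sanity checks outside the builder (a sketch, not the QNN code path itself):

    def normalize_slice(dim, start, end, shape):
        # mirrors the clamping logic in op_slice_copy.py above
        if dim < 0:
            dim = dim % len(shape)
        start = 0 if start is None else start
        if start < 0:
            start = start % shape[dim]
        if end is None:  # argument omitted -> slice runs to the end of the dim
            end = shape[dim]
        else:
            end = min(end, shape[dim])
            if end < 0:
                end = end % shape[dim]
        return dim, start, end

    # x[1:] on a (2, 1, 320, 512) tensor: dim 0, start 1, end omitted
    assert normalize_slice(0, 1, None, (2, 1, 320, 512)) == (0, 1, 2)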

backends/qualcomm/builders/op_to.py

+40-3
@@ -9,6 +9,7 @@
 
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+from executorch.exir.dialects._ops import ops as exir_ops
 
 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -90,9 +91,44 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        node_input_tensors = [input_tensor_wrapper]
+
+        # if the output dtype is int64, we should cast it to int32 first
+        # since int32 is the only source that can be casted into int64
+        ops = []
+        if (
+            node.meta["val"].dtype == torch.int64
+            or input_node.meta["val"].dtype == torch.int64
+        ):
+            cast_node = self.edge_program.graph.create_node(
+                "call_function",
+                exir_ops.edge.aten._to_copy.default,
+                (input_node,),
+                {"dtype": torch.int32},
+            )
+            cast_node.meta["val"] = input_node.meta["val"].to(torch.int32)
+            cast_tensor = self.get_tensor(cast_node, node)
+            cast_tensor_wrapper = self.define_tensor(
+                cast_node,
+                node,
+                cast_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                nodes_to_wrappers,
+            )
+            # graph is not allowed to be modified in partition stage
+            # erase it here to prevent lowering failure
+            self.edge_program.graph.erase_node(cast_node)
+            cast_op = PyQnnWrapper.PyQnnOpWrapper(
+                f"{node.name}_cast_i64_to_i32",
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpCast.op_name,
+            )
+            node_input_tensors = [cast_tensor_wrapper]
+            cast_op.AddInputTensors([input_tensor_wrapper])
+            cast_op.AddOutputTensors([cast_tensor_wrapper])
+            ops.append(cast_op)
 
         output_tensor = self.get_tensor(node, node)
-
         output_tensor_wrapper = self.define_tensor(
             node,
             node,
@@ -105,7 +141,8 @@ def define_node(
         op = PyQnnWrapper.PyQnnOpWrapper(
             node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
         )
-        op.AddInputTensors([input_tensor_wrapper])
+        op.AddInputTensors(node_input_tensors)
         op.AddOutputTensors([output_tensor_wrapper])
+        ops.append(op)
 
-        return op
+        return ops
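
Note: with the new i64 branch, an aten._to_copy touching int64 on either side is lowered as two ops: Cast(i64 -> i32) feeding the original Cast/Convert, with the visitor now returning a list of both. This is value-preserving as long as the integers fit in 32 bits. Eager-mode illustration of what the lowered pair computes:

    import torch

    x = torch.tensor([1, 2, 3], dtype=torch.int64)
    print(x.to(torch.int32).to(torch.float32))  # lowered: Cast(i64->i32) -> Cast(i32->f32)
    print(x.to(torch.float32))                  # eager reference; identical while |x| < 2**31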

backends/qualcomm/builders/qnn_constants.py

+6
@@ -252,6 +252,12 @@ class OpGather:
     param_axis: str = "axis"
 
 
+@dataclass(init=False, frozen=True)
+class OpGatherElements:
+    op_name: str = "GatherElements"
+    param_axis: str = "axis"
+
+
 @dataclass(init=False, frozen=True)
 class OpGatherND:
     op_name: str = "GatherNd"

backends/qualcomm/quantizer/annotators.py

+1-1
@@ -750,7 +750,7 @@ def annotate_elu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.embedding.default])
+@register_annotator([torch.ops.aten.embedding.default, torch.ops.aten.gather.default])
 def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
     weight = node.args[0]
 

backends/qualcomm/tests/models.py

+19
@@ -729,6 +729,17 @@ def forward(self, x):
         return torch.min(x, torch.full_like(x, self.fill))
 
 
+class Gather(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        index = torch.where(y > 0, torch.Tensor([1]).int(), torch.Tensor([1]).int()).to(
+            torch.int64
+        )
+        return torch.gather(x, x.dim() - 1, index)
+
+
 class Gelu(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1398,6 +1409,14 @@ def forward(self, x, y):
         return x[:, :seq_length] + self.position_ids[:, :seq_length]
 
 
+class SliceCopyDefaultParameter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.cat([x[:1], x[1:]], dim=1)
+
+
 class SliceCopyWithStep(torch.nn.Module):
     def __init__(self):
         super().__init__()
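
Note: in the Gather test module both torch.where branches are the same tensor, so the index is simply ones broadcast to y's shape, presumably to produce an int64 index of full rank (torch.gather requires index.dim() == input.dim()) while exercising the i64-to-i32 cast path. What the module computes, in eager mode:

    import torch

    x = torch.randn(2, 2, 3, 4)
    y = torch.randn(2, 2, 3, 4)
    index = torch.where(y > 0, torch.Tensor([1]).int(), torch.Tensor([1]).int()).to(torch.int64)
    out = torch.gather(x, x.dim() - 1, index)  # picks x[..., 1] everywhere
    assert torch.equal(out, x[..., 1:2].expand_as(x))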

backends/qualcomm/tests/test_qnn_delegate.py

+36-12
@@ -478,6 +478,13 @@ def test_qnn_backend_full_like(self):
         sample_input = (torch.randn(1, 2, 3, 4),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_gather(self):
+        module = Gather()  # noqa: F405
+        shape = (2, 2, 3, 4)
+        sample_input = (torch.randn(shape), torch.randn(shape))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -821,12 +828,17 @@ def test_qnn_backend_select_copy(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_slice_copy(self):
-        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
-        sample_input = (
-            torch.randn([1, 512]),
-            torch.randn([1, 8]),
-        )
-        for module in modules:
+        modules = [
+            SliceCopyDefaultParameter(),
+            SliceCopy(),
+            SliceCopyWithStep(),
+        ]  # noqa: F405
+        sample_inputs = [
+            (torch.randn([2, 1, 320, 512]),),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+        ]
+        for module, sample_input in zip(modules, sample_inputs):
             self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_stack(self):
@@ -1593,6 +1605,13 @@ def test_qnn_backend_full_like(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_gather(self):
+        module = Gather()  # noqa: F405
+        shape = (2, 2, 3, 4)
+        sample_input = (torch.randn(shape), torch.randn(shape))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -1991,12 +2010,17 @@ def test_qnn_backend_sin(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_slice_copy(self):
-        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
-        sample_input = (
-            torch.randn([1, 512]),
-            torch.randn([1, 8]),
-        )
-        for module in modules:
+        modules = [
+            SliceCopyDefaultParameter(),
+            SliceCopy(),
+            SliceCopyWithStep(),
+        ]  # noqa: F405
+        sample_inputs = [
+            (torch.randn([2, 1, 320, 512]),),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+        ]
+        for module, sample_input in zip(modules, sample_inputs):
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)
 
