Skip to content

Qualcomm AI Engine Direct - xr model enablement (mld_f) #10546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions backends/cadence/hifi/operators/op_bmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ using exec_aten::ScalarType;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::kTensorDimensionLimit;
using executorch::runtime::resize_tensor;
using executorch::runtime::tensors_have_same_dim_order;
using executorch::runtime::tensor_is_default_dim_order;
using executorch::runtime::tensors_have_same_dim_order;
using torch::executor::check_bmm_args;
using torch::executor::Error;
using torch::executor::get_bmm_out_target_size;
Expand Down Expand Up @@ -78,16 +78,16 @@ Tensor& bmm_out(
WORD32 out_stride = p;

WORD32* __restrict__ tmp =
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, (batch_size * m * p) * sizeof(float));
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, (batch_size * m * p) * sizeof(float));

ET_KERNEL_CHECK(ctx, tmp != nullptr, MemoryAllocationFailed, out);

tmp[batch_size * m * p] = {0};

WORD32* __restrict__ p_o =
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, (batch_size * m * p) * sizeof(WORD32));
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, (batch_size * m * p) * sizeof(WORD32));

ET_KERNEL_CHECK(ctx, p_o != nullptr, MemoryAllocationFailed, out);

Expand Down
4 changes: 2 additions & 2 deletions backends/cadence/hifi/operators/op_mm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ Tensor& mm_out(
WORD32 out_stride = p;

WORD32* __restrict__ p_o =
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, (n * p) * sizeof(WORD32));
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, (n * p) * sizeof(WORD32));

WORD32 p_inp_shape[2];
p_inp_shape[0] = n;
Expand Down
29 changes: 29 additions & 0 deletions backends/qualcomm/_passes/i64_to_i32.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ class I64toI32(ExportPass):
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.scalar_tensor.default,
}
# This dict ensures that the listed inputs of these ops are int64, as required by PyTorch.
# For example, the scatter op only accepts args[2] (the index) as int64.
# Key: op whose inputs must be cast to i64
# Value: indices of the args that need a cast op inserted
I64_IN_OPS = {
exir_ops.edge.aten.gather.default: [2],
exir_ops.edge.aten.scatter.src: [2],
}
copy_op = exir_ops.edge.aten._to_copy.default

def __init__(
Expand Down Expand Up @@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
n.replace_all_uses_with(to_dst_node)
to_dst_node.args = (n,)

def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
    """Insert an int64 cast in front of selected operator arguments.

    For every node whose target is a key of ``self.I64_IN_OPS``, each
    argument position listed for that target is rewired to a new
    ``self.copy_op`` node (``dtype=torch.int64``) inserted right before
    the consuming node.

    Args:
        graph_module: Graph module whose graph is modified in place.
    """
    # Inputs will be cast to i32 during call_operator dtype propagation;
    # insert an i64 cast node to prevent operator validation failure.
    graph = graph_module.graph
    for node in graph.nodes:
        if node.target not in self.I64_IN_OPS:
            continue
        with graph.inserting_before(node):
            new_args = list(node.args)
            for arg_index in self.I64_IN_OPS[node.target]:
                input_node = new_args[arg_index]
                cast_i64_node = graph.create_node(
                    "call_function",
                    self.copy_op,
                    (input_node,),
                    {"dtype": torch.int64},
                )
                # The cast's value must mirror the argument being cast, not
                # the consumer's output: their shapes can differ (e.g. the
                # index of scatter vs. its result).
                cast_i64_node.meta["val"] = input_node.meta["val"].to(torch.int64)
                new_args[arg_index] = cast_i64_node
            node.args = tuple(new_args)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Record original output dtype to ensure that if user expects int64 as output,
# convert the output back to int64 if it is casted from int64->int32.
self._record_original_output_dtype(graph_module)
self._cast_constant_to_int32(graph_module)
self._cast_op_args_to_i64(graph_module)
graph_module = super().call(graph_module).graph_module
self._preserve_output_dtype(graph_module)
graph_module.recompile()
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def transform_for_to_edge_pipeline(

# Before quantizer
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(RemoveRedundancy(quantization_capture=True))
self.add_pass(ReduceDynamicRange())
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
self.add_pass(ReplaceArangeArgs())
Expand Down
17 changes: 15 additions & 2 deletions backends/qualcomm/_passes/remove_redundancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class RemoveRedundancy(ExportPass):
Trim certain operators to reduce unnecessary overhead.
"""

def __init__(self):
def __init__(self, quantization_capture=False):
super(RemoveRedundancy, self).__init__()
self.redundant_ops = {
self.redundant_ops_general = {
torch.clone: self._default_condition,
torch.ops.aten.clone.default: self._default_condition,
exir_ops.edge.aten.clone.default: self._default_condition,
Expand All @@ -27,7 +27,16 @@ def __init__(self):
exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
# remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
}
self.redundant_ops_annotation = {
torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
}
self.redundant_ops = (
self.redundant_ops_annotation
if quantization_capture
else self.redundant_ops_general
)

def _dim_order_op_condition(self, node):
dim_order = node.kwargs.get("dim_order")
Expand All @@ -49,6 +58,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
continue

to_be_remove = n
# assert_tensor_metadata op has no user
if len(n.users.keys()) == 0:
n.args = ()
# normal case
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
graph_module.graph.erase_node(to_be_remove)
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
op_expand,
op_full,
op_full_like,
op_gather,
op_ge,
op_gelu,
op_group_norm,
Expand Down Expand Up @@ -120,6 +121,7 @@
op_expand,
op_full,
op_full_like,
op_gather,
op_ge,
op_gelu,
op_group_norm,
Expand Down
101 changes: 101 additions & 0 deletions backends/qualcomm/builders/op_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpCast, OpGatherElements, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Gather(NodeVisitor):
    """Node visitor that lowers ``aten.gather.default`` to QNN ``GatherElements``.

    The indices operand is first routed through a QNN ``Cast`` op whose
    output tensor is described as int32; the gather op consumes that
    cast output instead of the original indices tensor.
    """

    target = ["aten.gather.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN ops for one gather node.

        Args:
            node: The gather node; ``args`` are ``(input, dim, index)``.
            nodes_to_wrappers: Tensor-wrapper cache passed through to
                ``define_tensor``.

        Returns:
            A two-element list ``[cast_op, gather_op]``.
            NOTE(review): the return annotation names a single op wrapper
            although a list is returned; it is left untouched here so no
            new typing import is required.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        dim = cast(int, node.args[1])
        # The axis is emitted below as an unsigned 32-bit scalar, so a
        # negative (from-the-end) dim must be wrapped into [0, rank) first.
        if dim < 0:
            dim = dim % len(input_tensor.shape)

        indices_node = node.args[2]
        indices_tensor = self.get_tensor(indices_node, node)
        indices_tensor_wrapper = self.define_tensor(
            indices_node,
            node,
            indices_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        # Temporary fx node that only serves to describe the int32 view of
        # the indices so a tensor wrapper can be defined for it.
        cast_node = self.edge_program.graph.create_node(
            "call_function",
            exir_ops.edge.aten._to_copy.default,
            (indices_node,),
            {"dtype": torch.int32},
        )
        cast_node.meta["val"] = indices_node.meta["val"].to(torch.int32)
        cast_tensor = self.get_tensor(cast_node, node)
        cast_tensor_wrapper = self.define_tensor(
            cast_node,
            node,
            cast_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        # graph is not allowed to be modified in partition stage
        # erase it here to prevent lowering failure
        self.edge_program.graph.erase_node(cast_node)

        cast_op = PyQnnWrapper.PyQnnOpWrapper(
            f"{node.name}_cast_i64_to_i32", QNN_OP_PACKAGE_NAME_QTI_AISW, OpCast.op_name
        )
        cast_op.AddInputTensors([indices_tensor_wrapper])
        cast_op.AddOutputTensors([cast_tensor_wrapper])

        gather_input_tensors = [input_tensor_wrapper, cast_tensor_wrapper]
        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        gather_output_tensors = [output_tensor_wrapper]

        gather_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGatherElements.op_name,
        )
        gather_op.AddInputTensors(gather_input_tensors)
        gather_op.AddOutputTensors(gather_output_tensors)
        gather_op.AddScalarParam(
            OpGatherElements.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(dim)},
        )

        return [cast_op, gather_op]
13 changes: 9 additions & 4 deletions backends/qualcomm/builders/op_slice_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,17 @@ def define_node(
dim = cast(int, node.args[1])
if dim < 0:
dim = dim % len(input_tensor.shape)
start = cast(int, node.args[2])

start = 0 if node.args[2] is None else cast(int, node.args[2])
if start < 0:
start = start % input_tensor.shape[dim]
end = min(cast(int, node.args[3]), input_tensor.shape[dim])
if end < 0:
end = end % input_tensor.shape[dim]

if len(node.args) > 3:
end = min(cast(int, node.args[3]), input_tensor.shape[dim])
if end < 0:
end = end % input_tensor.shape[dim]
else:
end = input_tensor.shape[dim]

input_tensor_rank = len(input_tensor.shape)
ranges = []
Expand Down
47 changes: 44 additions & 3 deletions backends/qualcomm/builders/op_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW
Expand Down Expand Up @@ -90,9 +91,48 @@ def define_node(
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
node_input_tensors = [input_tensor_wrapper]

# If the output / input dtype is int64, we should cast it to int32 first,
# since int32 is the only source type that can be cast into int64.
ops = []
if (
(
node.meta["val"].dtype == torch.int64
or input_node.meta["val"].dtype == torch.int64
)
# no need to add another cast node if the dtype is already integer type
and input_node.meta["val"].dtype not in (torch.int32, torch.int64)
):
cast_node = self.edge_program.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(input_node,),
{"dtype": torch.int32},
)
cast_node.meta["val"] = input_node.meta["val"].to(torch.int32)
cast_tensor = self.get_tensor(cast_node, node)
cast_tensor_wrapper = self.define_tensor(
cast_node,
node,
cast_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
# graph is not allowed to be modified in partition stage
# erase it here to prevent lowering failure
self.edge_program.graph.erase_node(cast_node)
cast_op = PyQnnWrapper.PyQnnOpWrapper(
f"{node.name}_cast_i64_to_i32",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpCast.op_name,
)
node_input_tensors = [cast_tensor_wrapper]
cast_op.AddInputTensors([input_tensor_wrapper])
cast_op.AddOutputTensors([cast_tensor_wrapper])
ops.append(cast_op)

output_tensor = self.get_tensor(node, node)

output_tensor_wrapper = self.define_tensor(
node,
node,
Expand All @@ -105,7 +145,8 @@ def define_node(
op = PyQnnWrapper.PyQnnOpWrapper(
node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
)
op.AddInputTensors([input_tensor_wrapper])
op.AddInputTensors(node_input_tensors)
op.AddOutputTensors([output_tensor_wrapper])
ops.append(op)

return op
return ops
6 changes: 6 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ class OpGather:
param_axis: str = "axis"


@dataclass(init=False, frozen=True)
class OpGatherElements:
    """String constants for the "GatherElements" op: its op name and the key of its "axis" parameter."""
    op_name: str = "GatherElements"
    param_axis: str = "axis"


@dataclass(init=False, frozen=True)
class OpGatherND:
op_name: str = "GatherNd"
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/quantizer/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def annotate_elu(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.embedding.default])
@register_annotator([torch.ops.aten.embedding.default, torch.ops.aten.gather.default])
def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
weight = node.args[0]

Expand Down
Loading
Loading