Skip to content

Commit 3eb57fa

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for scatter.src core ATen op (#19283)
### Summary Added support for the core ATen op `scatter.src` using an op builder with the [QNN implementation](https://docs.qualcomm.com/doc/80-63442-10/topic/HtpOpDefSupplement.html#scatterelements) for `ScatterElements`. Note `scatter.src` uses `ScatterElements` directly with the argument `reduction=NONE`. ### Test plan ``` python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_scatter_src --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_scatter_src --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android ```
1 parent 477707f commit 3eb57fa

10 files changed

Lines changed: 290 additions & 2 deletions

File tree

backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class LayoutTransform(ExportPass):
120120
exir_ops.edge.aten.repeat.default,
121121
exir_ops.edge.aten.relu.default,
122122
exir_ops.edge.aten.round.default,
123+
exir_ops.edge.aten.scatter.src,
123124
exir_ops.edge.aten.sigmoid.default,
124125
exir_ops.edge.aten.sign.default,
125126
exir_ops.edge.aten.slice_copy.Tensor,

backends/qualcomm/builders/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators:
368368
+ 🚫 = Deprecated, supported with other QNN Ops
369369

370370

371-
| Operators | HTP - 99/120 Enabled |
371+
| Operators | HTP - 100/120 Enabled |
372372
|-----------|---------|
373373
| Argmax | ✓ |
374374
| Argmin | ✓ |
@@ -473,7 +473,7 @@ Please help update following table if you are contributing new operators:
473473
| ResizeNearestNeighbor | ✓ |
474474
| RoiAlign | ✗ |
475475
| RmsNorm | ✓ |
476-
| ScatterElements | ✗ |
476+
| ScatterElements | ✓ |
477477
| ScatterNd | ✓ |
478478
| Sigmoid | ✓ |
479479
| Softmax | ✓ |

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
op_round,
9191
op_rsqrt,
9292
op_scalar_tensor,
93+
op_scatter_elements,
9394
op_select_copy,
9495
op_sigmoid,
9596
op_sign,
@@ -204,6 +205,7 @@
204205
op_round,
205206
op_rsqrt,
206207
op_scalar_tensor,
208+
op_scatter_elements,
207209
op_select_copy,
208210
op_sigmoid,
209211
op_sign,
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
16+
from .qnn_constants import OpScatterElements, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
20+
class ScatterElements(NodeVisitor):
21+
target = ["aten.scatter.src"]
22+
23+
def __init__(self, *args) -> None:
24+
super().__init__(*args)
25+
26+
def define_node(
27+
self,
28+
node: torch.fx.Node,
29+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
30+
) -> PyQnnManager.PyQnnOpWrapper:
31+
input_node = self.get_node(node.args[0])
32+
input_tensor = self.get_tensor(input_node, node)
33+
input_tensor_wrapper = self.define_tensor(
34+
input_node,
35+
node,
36+
input_tensor,
37+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
38+
nodes_to_wrappers,
39+
)
40+
41+
index_node = self.get_node(node.args[2])
42+
index_tensor = self.get_tensor(index_node, node)
43+
index_tensor_wrapper = self.define_tensor(
44+
index_node,
45+
node,
46+
index_tensor.to(torch.int32),
47+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
48+
nodes_to_wrappers,
49+
)
50+
51+
updates_node = self.get_node(node.args[3])
52+
updates_tensor = self.get_tensor(updates_node, node)
53+
updates_tensor_wrapper = self.define_tensor(
54+
updates_node,
55+
node,
56+
updates_tensor,
57+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
58+
nodes_to_wrappers,
59+
)
60+
61+
output_tensor = self.get_tensor(node, node)
62+
output_tensor_wrapper = self.define_tensor(
63+
node,
64+
node,
65+
output_tensor,
66+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
67+
nodes_to_wrappers,
68+
)
69+
70+
dim = node.args[1]
71+
if dim < 0:
72+
dim = dim % len(input_tensor.shape)
73+
74+
if QCOM_AXIS_ORDER in node.meta:
75+
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
76+
77+
scatter_op = PyQnnManager.PyQnnOpWrapper(
78+
node.name,
79+
QNN_OP_PACKAGE_NAME_QTI_AISW,
80+
OpScatterElements.op_name,
81+
)
82+
scatter_op.AddInputTensors(
83+
[
84+
input_tensor_wrapper,
85+
index_tensor_wrapper,
86+
updates_tensor_wrapper,
87+
]
88+
)
89+
scatter_op.AddOutputTensors([output_tensor_wrapper])
90+
91+
scatter_op.AddScalarParam(
92+
OpScatterElements.param_axis,
93+
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
94+
{QCOM_DATA: np.uint32(dim)},
95+
)
96+
97+
scatter_op.AddScalarParam(
98+
OpScatterElements.param_reduction,
99+
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
100+
{QCOM_DATA: np.uint32(OpScatterElements.Reduction.NONE)},
101+
)
102+
103+
return scatter_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,17 @@ class OpRmsNorm:
594594
param_axes: str = "axes"
595595

596596

597+
@dataclass(init=False, frozen=True)
598+
class OpScatterElements:
599+
op_name: str = "ScatterElements"
600+
param_axis: str = "axis"
601+
param_reduction: str = "reduction"
602+
603+
@unique
604+
class Reduction(IntEnum):
605+
NONE = 0
606+
607+
597608
@dataclass(init=False, frozen=True)
598609
class OpScatterNd:
599610
op_name: str = "ScatterNd"

backends/qualcomm/partition/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
6868
torch.ops.aten.reflection_pad2d.default,
6969
torch.ops.aten.rms_norm.default,
7070
torch.ops.aten._safe_softmax.default,
71+
torch.ops.aten.scatter.src,
7172
torch.ops.aten.stack.default,
7273
torch.ops.aten.upsample_bicubic2d.vec,
7374
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,44 @@ class ScaledDotProductAttention(GeneralOpDef):
13911391
pass
13921392

13931393

1394+
@register_annotator(
1395+
[torch.ops.aten.scatter.src],
1396+
qnn_op=None,
1397+
)
1398+
class ScatterElements(GeneralOpDef):
1399+
@staticmethod
1400+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
1401+
if _is_annotated([node]):
1402+
return
1403+
1404+
input_act = node.args[0]
1405+
if not isinstance(input_act, Node) or not _is_float_tensor(input_act):
1406+
return
1407+
1408+
input_qspec_map = {}
1409+
input_qspec_map[input_act] = quantization_config.input_activation
1410+
1411+
if (
1412+
len(node.args) > 3
1413+
and isinstance(node.args[3], Node)
1414+
and _is_float_tensor(node.args[3])
1415+
):
1416+
input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))
1417+
1418+
output_act_qspec = (
1419+
SharedQuantizationSpec((input_act, node))
1420+
if _is_float_tensor(node)
1421+
else None
1422+
)
1423+
1424+
if len(input_qspec_map) > 0 or output_act_qspec is not None:
1425+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
1426+
input_qspec_map=input_qspec_map,
1427+
output_qspec=output_act_qspec,
1428+
_annotated=True,
1429+
)
1430+
1431+
13941432
@register_annotator(
13951433
[torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default],
13961434
QnnConstants.OpSigmoid.op_name,

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,44 @@ class ScaledDotProductAttention(GeneralOpDef):
869869
pass
870870

871871

872+
@register_annotator(
873+
[torch.ops.aten.scatter.src],
874+
qnn_op=None,
875+
)
876+
class ScatterElements(GeneralOpDef):
877+
@staticmethod
878+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
879+
if _is_annotated([node]):
880+
return
881+
882+
input_act = node.args[0]
883+
if not isinstance(input_act, Node) or not _is_float_tensor(input_act):
884+
return
885+
886+
input_qspec_map = {}
887+
input_qspec_map[input_act] = quantization_config.input_activation
888+
889+
if (
890+
len(node.args) > 3
891+
and isinstance(node.args[3], Node)
892+
and _is_float_tensor(node.args[3])
893+
):
894+
input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))
895+
896+
output_act_qspec = (
897+
SharedQuantizationSpec((input_act, node))
898+
if _is_float_tensor(node)
899+
else None
900+
)
901+
902+
if len(input_qspec_map) > 0 or output_act_qspec is not None:
903+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
904+
input_qspec_map=input_qspec_map,
905+
output_qspec=output_act_qspec,
906+
_annotated=True,
907+
)
908+
909+
872910
@register_annotator(
873911
[torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default],
874912
QnnConstants.OpSigmoid.op_name,

backends/qualcomm/tests/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,6 +2201,15 @@ def forward(self, query_layer, key_layer, value_layer, attn_mask):
22012201
return attn_output
22022202

22032203

2204+
class ScatterSrc(torch.nn.Module):
2205+
def __init__(self, dim=1):
2206+
super().__init__()
2207+
self.dim = dim
2208+
2209+
def forward(self, data, index, src):
2210+
return torch.scatter(data, self.dim, index, src)
2211+
2212+
22042213
class SelectCopy(torch.nn.Module):
22052214
def __init__(self):
22062215
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,6 +1930,52 @@ def test_qnn_backend_round(self):
19301930
sample_input = (torch.randn([3, 4]),)
19311931
self.lower_module_and_test_output(module, sample_input)
19321932

1933+
def test_qnn_backend_scatter_src(self):
1934+
test_comb = [
1935+
{
1936+
QCOM_MODULE: [ScatterSrc(dim=1)], # noqa: F405
1937+
QCOM_SAMPLE_INPUTS: [
1938+
(
1939+
torch.zeros(3, 5),
1940+
torch.tensor(
1941+
[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
1942+
dtype=torch.int64,
1943+
),
1944+
torch.rand(3, 5),
1945+
),
1946+
(
1947+
torch.zeros(3, 5, dtype=torch.float16),
1948+
torch.tensor(
1949+
[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
1950+
dtype=torch.int64,
1951+
),
1952+
torch.rand(3, 5, dtype=torch.float16),
1953+
),
1954+
],
1955+
},
1956+
{
1957+
QCOM_MODULE: [ScatterSrc(dim=0)], # noqa: F405
1958+
QCOM_SAMPLE_INPUTS: [
1959+
(
1960+
torch.zeros(3, 5),
1961+
torch.tensor(
1962+
[[2, 1, 0, 1, 2], [0, 2, 1, 2, 0], [1, 0, 2, 0, 1]],
1963+
dtype=torch.int64,
1964+
),
1965+
torch.rand(3, 5),
1966+
),
1967+
],
1968+
},
1969+
]
1970+
1971+
index = 0
1972+
for comb in test_comb:
1973+
for module in comb[QCOM_MODULE]:
1974+
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
1975+
with self.subTest(i=index):
1976+
index += 1
1977+
self.lower_module_and_test_output(module, sample_input)
1978+
19331979
def test_qnn_backend_rsqrt(self):
19341980
module = Rsqrt() # noqa: F405
19351981
sample_input = (torch.abs(torch.randn([3, 4])),)
@@ -4722,6 +4768,45 @@ def test_qnn_backend_rsqrt(self):
47224768
module = self.get_qdq_module(module, sample_input)
47234769
self.lower_module_and_test_output(module, sample_input)
47244770

4771+
def test_qnn_backend_scatter_src(self):
4772+
test_comb = [
4773+
{
4774+
QCOM_MODULE: [ScatterSrc(dim=1)], # noqa: F405
4775+
QCOM_SAMPLE_INPUTS: [
4776+
(
4777+
torch.zeros(3, 5),
4778+
torch.tensor(
4779+
[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
4780+
dtype=torch.int64,
4781+
),
4782+
torch.rand(3, 5),
4783+
),
4784+
],
4785+
},
4786+
{
4787+
QCOM_MODULE: [ScatterSrc(dim=0)], # noqa: F405
4788+
QCOM_SAMPLE_INPUTS: [
4789+
(
4790+
torch.zeros(3, 5),
4791+
torch.tensor(
4792+
[[2, 1, 0, 1, 2], [0, 2, 1, 2, 0], [1, 0, 2, 0, 1]],
4793+
dtype=torch.int64,
4794+
),
4795+
torch.rand(3, 5),
4796+
),
4797+
],
4798+
},
4799+
]
4800+
4801+
index = 0
4802+
for comb in test_comb:
4803+
for module in comb[QCOM_MODULE]:
4804+
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
4805+
with self.subTest(i=index):
4806+
index += 1
4807+
qdq_module = self.get_qdq_module(module, sample_input)
4808+
self.lower_module_and_test_output(qdq_module, sample_input)
4809+
47254810
def test_qnn_backend_sdpa(self):
47264811
modules = [
47274812
ScaledDotProductAttention(), # noqa: F405

0 commit comments

Comments
 (0)