Skip to content

Commit 0a5b2d9

Browse files
Enable Sub Tensor with new Neutron flow
1 parent 5b89d23 commit 0a5b2d9

3 files changed

Lines changed: 323 additions & 15 deletions

File tree

backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import torch
7+
8+
from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
69
from executorch.backends.nxp.backend.ir.converter.node_converter import (
710
CustomDelegationOptions,
811
NodeConverter,
@@ -23,11 +26,33 @@ def _is_supported_on_target(
2326
parameters_mapping: dict[str, Parameter],
2427
custom_delegation_options: CustomDelegationOptions,
2528
) -> bool:
26-
if NodeConverter.uses_shape_broadcasting(node):
27-
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
28-
return False
29+
if custom_delegation_options.use_new_flow_neutron_c:
30+
if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(
31+
node
32+
):
33+
return False
2934

30-
return True
35+
# If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes
36+
# Transpose is currently not supported for new flow
37+
if any(
38+
input_node.meta[NXP_NODE_FORMAT].is_channels_first()
39+
for input_node in node.all_input_nodes
40+
) and NodeConverter._node_inputs_ranks_not_equal(node):
41+
return False
42+
43+
supported_types = [torch.int8, torch.uint8]
44+
if not NodeConverter.uses_quantization_type_for_io(
45+
node, supported_types, [0, 1], [0]
46+
):
47+
return False
48+
49+
return True
50+
else:
51+
if NodeConverter.uses_shape_broadcasting(node):
52+
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
53+
return False
54+
55+
return True
3156

3257
@staticmethod
3358
def _is_supported_in_IR(
@@ -45,9 +70,12 @@ def _is_supported_in_IR(
4570

4671
return True
4772

48-
# sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
4973
def convert(self, node: Node):
50-
"""Convert 'sub_tensor' operator to NeutronIR 'Sub'."""
74+
"""Convert 'sub_tensor' operator to NeutronIR 'Sub'.
75+
The ExecuTorch schema is:
76+
sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)
77+
"""
78+
5179
self.assert_convertible(node)
5280

5381
t_op = self._create_tflite_op_with_io_tensors(node)

0 commit comments

Comments
 (0)