33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ import torch
7+
8+ from executorch .backends .nxp .backend .data_format import NXP_NODE_FORMAT
69from executorch .backends .nxp .backend .ir .converter .node_converter import (
710 CustomDelegationOptions ,
811 NodeConverter ,
@@ -23,11 +26,33 @@ def _is_supported_on_target(
2326 parameters_mapping : dict [str , Parameter ],
2427 custom_delegation_options : CustomDelegationOptions ,
2528 ) -> bool :
26- if NodeConverter .uses_shape_broadcasting (node ):
27- # Shape broadcasting may require the addition of `Transpose` ops during conversion.
28- return False
29+ if custom_delegation_options .use_new_flow_neutron_c :
30+ if not NodeConverter .at_least_one_input_shape_matches_the_output_shape (
31+ node
32+ ):
33+ return False
2934
30- return True
35+ # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes
36+ # Transpose is currently not supported for new flow
37+ if any (
38+ input_node .meta [NXP_NODE_FORMAT ].is_channels_first ()
39+ for input_node in node .all_input_nodes
40+ ) and NodeConverter ._node_inputs_ranks_not_equal (node ):
41+ return False
42+
43+ supported_types = [torch .int8 , torch .uint8 ]
44+ if not NodeConverter .uses_quantization_type_for_io (
45+ node , supported_types , [0 , 1 ], [0 ]
46+ ):
47+ return False
48+
49+ return True
50+ else :
51+ if NodeConverter .uses_shape_broadcasting (node ):
52+ # Shape broadcasting may require the addition of `Transpose` ops during conversion.
53+ return False
54+
55+ return True
3156
3257 @staticmethod
3358 def _is_supported_in_IR (
@@ -45,9 +70,12 @@ def _is_supported_in_IR(
4570
4671 return True
4772
48- # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
4973 def convert (self , node : Node ):
50- """Convert 'sub_tensor' operator to NeutronIR 'Sub'."""
74+ """Convert 'sub_tensor' operator to NeutronIR 'Sub'.
75+ The ExecuTorch schema is:
76+ sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)
77+ """
78+
5179 self .assert_convertible (node )
5280
5381 t_op = self ._create_tflite_op_with_io_tensors (node )
0 commit comments