Skip to content

Commit 63fb2ae

Browse files
committed
Xnnpack: Support clone.default with skip_dim_order=True
With the default XNNPACK test config, skip_dim_order=False rewrites aten.clone.default to dim_order_ops._clone_dim_order.default. That path is already supported through CloneDimOrderConfig.

Some XNNPACK export flows use skip_dim_order=True, where aten.clone.default stays as aten.clone.default and is not selected by the partitioner.

Adds CloneConfig for dim-order-preserving aten.clone.default nodes so this path is partitioned directly. This reduces delegate splits in the EdgeTAM mask decoder, where profiling exports use skip_dim_order=True.

Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
Change-Id: Ic48ec187f26048b68a805c6edd6dad41b3dab481
1 parent ee4c90a commit 63fb2ae

4 files changed

Lines changed: 68 additions & 5 deletions

File tree

backends/xnnpack/operators/op_clone.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -13,6 +14,7 @@
1314
NodeVisitor,
1415
register_node_visitor,
1516
)
17+
from executorch.backends.xnnpack.operators.quant_params import QuantParams
1618
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
1719
XNNCopy,
1820
XNNGraph,
@@ -25,17 +27,26 @@
2527
class CloneVisitor(NodeVisitor):
2628
target = "aten.clone.default"
2729

28-
def __init__(self, *args) -> None:
29-
super().__init__(*args)
30-
3130
def define_node(
3231
self,
3332
node: torch.fx.Node,
3433
xnn_graph: XNNGraph,
3534
vals_to_ids: Dict[torch.fx.Node, int],
3635
debug_handle: int,
3736
) -> None:
38-
self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
37+
self.define_tensor(
38+
node,
39+
xnn_graph,
40+
vals_to_ids,
41+
quant_params=QuantParams.from_outputs(node),
42+
)
43+
input_node = get_input_node(node, 0)
44+
self.define_tensor(
45+
input_node,
46+
xnn_graph,
47+
vals_to_ids,
48+
quant_params=QuantParams.from_inputs(input_node, self._exported_program),
49+
)
3950

4051
# Sanity check that the input and output dim order are the same. We don't
4152
# handle dim order conversions yet.

backends/xnnpack/partition/config/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -23,6 +24,7 @@
2324
CatConfig,
2425
CeilConfig,
2526
ClampConfig,
27+
CloneConfig,
2628
CloneDimOrderConfig,
2729
ConstantPadConfig,
2830
CosConfig,
@@ -82,6 +84,7 @@
8284
BMMConfig,
8385
CatConfig,
8486
CeilConfig,
87+
CloneConfig,
8588
CloneDimOrderConfig,
8689
ConstantPadConfig,
8790
ConvolutionConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,27 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
239239
return [ConfigPrecisionType.FP32]
240240

241241

242+
class CloneConfig(GenericNodePartitionerConfig):
243+
target_name = "clone.default"
244+
245+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
246+
return [ConfigPrecisionType.FP32]
247+
248+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
249+
if not self.check_common_constraints(node, ep):
250+
return False
251+
252+
input_meta = node.args[0].meta["val"]
253+
output_meta = node.meta["val"]
254+
input_dim_order = list(input_meta.dim_order())
255+
output_dim_order = list(output_meta.dim_order())
256+
if input_dim_order != output_dim_order:
257+
why(node, reason="Only dim-order preserving clones are supported.")
258+
return False
259+
260+
return True
261+
262+
242263
class ClampConfig(GenericNodePartitionerConfig):
243264
target_name = "clamp.default"
244265

backends/xnnpack/test/ops/test_clone.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -9,7 +10,8 @@
910
import unittest
1011

1112
import torch
12-
from executorch.backends.xnnpack.test.tester import Tester
13+
from executorch.backends.xnnpack.test.tester import Tester, ToEdgeTransformAndLower
14+
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
1315

1416

1517
class TestClone(unittest.TestCase):
@@ -62,6 +64,32 @@ def test_fp32_clone(self):
6264
inputs = (torch.randn(2, 3, 4, 5),)
6365
self._test_clone_partitioned(inputs)
6466

67+
def test_fp32_clone_default_partitions_with_skip_dim_order(self):
68+
"""Test plain aten.clone.default partitioning without dim-order rewrite."""
69+
inputs = (torch.randn(2, 3, 4, 5),)
70+
(
71+
Tester(self.Clone(), inputs)
72+
.export()
73+
.check_count({"torch.ops.aten.clone.default": 1})
74+
.to_edge_transform_and_lower(
75+
ToEdgeTransformAndLower(
76+
edge_compile_config=get_xnnpack_edge_compile_config(
77+
skip_dim_order=True
78+
)
79+
)
80+
)
81+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
82+
.check_not(
83+
[
84+
"executorch_exir_dialects_edge__ops_aten_clone_default",
85+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
86+
]
87+
)
88+
.to_executorch()
89+
.serialize()
90+
.run_method_and_compare_outputs()
91+
)
92+
6593
def test_fp32_clone_2d(self):
6694
"""Test FP32 clone with 2D tensor - should be partitioned"""
6795
inputs = (torch.randn(10, 20),)

0 commit comments

Comments
 (0)