Skip to content

Commit 3920485

Browse files
Arm backend: Add TOSA dialect data layout ops
Adds TOSA dialect fake implementations for CONCAT, RESHAPE, REVERSE, TILE and TRANSPOSE. Also moves PAD and SLICE into data_layout_ops.py. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Change-Id: I93adb38dcfa4382b0bb60853c45db252de5f4250
1 parent 7282106 commit 3920485

5 files changed

Lines changed: 455 additions & 128 deletions

File tree

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import pytest
8+
import torch
9+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
10+
from executorch.backends.arm.tosa.specification import (
11+
TosaLoweringContext,
12+
TosaSpecification,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from torch._subclasses.fake_tensor import FakeTensorMode
16+
17+
18+
def _fake_tensor(dtype: torch.dtype, mode: FakeTensorMode) -> torch.Tensor:
19+
return mode.from_tensor(torch.empty((2, 3), dtype=dtype))
20+
21+
22+
_DATA_LAYOUT_OPS = [
23+
pytest.param(
24+
lambda x: exir_ops.backend.tosa.CONCAT.default([x, x], axis=0),
25+
(4, 3),
26+
id="concat",
27+
),
28+
pytest.param(
29+
lambda x: exir_ops.backend.tosa.PAD.default(x, [1, 2, 3, 4], value=0),
30+
(5, 10),
31+
id="pad",
32+
),
33+
pytest.param(
34+
lambda x: exir_ops.backend.tosa.RESHAPE.default(x, [3, 2]),
35+
(3, 2),
36+
id="reshape",
37+
),
38+
pytest.param(
39+
lambda x: exir_ops.backend.tosa.REVERSE.default(x, axis=0),
40+
(2, 3),
41+
id="reverse",
42+
),
43+
pytest.param(
44+
lambda x: exir_ops.backend.tosa.SLICE.default(x, [0, 1], [2, 2]),
45+
(2, 2),
46+
id="slice",
47+
),
48+
pytest.param(
49+
lambda x: exir_ops.backend.tosa.TILE.default(x, [1, 2]),
50+
(2, 6),
51+
id="tile",
52+
),
53+
pytest.param(
54+
lambda x: exir_ops.backend.tosa.TRANSPOSE.default(x, [1, 0]),
55+
(3, 2),
56+
id="transpose",
57+
),
58+
]
59+
60+
_POSITIVE_DTYPES = [
61+
pytest.param("TOSA-1.1+FP", torch.float32, id="fp32"),
62+
pytest.param("TOSA-1.1+INT", torch.int32, id="int32"),
63+
pytest.param("TOSA-1.1+FP", torch.bool, id="bool"),
64+
pytest.param("TOSA-1.1+INT+int64", torch.int64, id="int64"),
65+
pytest.param("TOSA-1.1+FP+bf16", torch.bfloat16, id="bf16"),
66+
pytest.param("TOSA-1.1+FP+fp8e4m3", torch.float8_e4m3fn, id="fp8e4m3"),
67+
pytest.param("TOSA-1.1+FP+fp8e5m2", torch.float8_e5m2, id="fp8e5m2"),
68+
]
69+
70+
71+
@pytest.mark.parametrize("spec,dtype", _POSITIVE_DTYPES)
72+
@pytest.mark.parametrize("op,expected_shape", _DATA_LAYOUT_OPS)
73+
def test_data_layout_ops_positive(op, expected_shape, spec, dtype) -> None:
74+
with TosaLoweringContext(
75+
TosaSpecification.create_from_string(spec)
76+
), FakeTensorMode() as mode:
77+
output = op(_fake_tensor(dtype, mode))
78+
79+
assert output.dtype == dtype
80+
assert tuple(output.shape) == expected_shape
81+
82+
83+
@pytest.mark.parametrize(
84+
"op,error_match",
85+
[
86+
pytest.param(
87+
lambda x: exir_ops.backend.tosa.CONCAT.default([x, x], axis=2),
88+
"out of range",
89+
id="concat",
90+
),
91+
pytest.param(
92+
lambda x: exir_ops.backend.tosa.PAD.default(x, [0, -1, 0, 0], value=0),
93+
"non-negative",
94+
id="pad",
95+
),
96+
pytest.param(
97+
lambda x: exir_ops.backend.tosa.RESHAPE.default(x, [-2, -3]),
98+
"Negative dimension",
99+
id="reshape",
100+
),
101+
pytest.param(
102+
lambda x: exir_ops.backend.tosa.REVERSE.default(x, axis=2),
103+
"out of range",
104+
id="reverse",
105+
),
106+
pytest.param(
107+
lambda x: exir_ops.backend.tosa.SLICE.default(x, [0, 0], [2, 0]),
108+
r"Expected start \+ size",
109+
id="slice",
110+
),
111+
pytest.param(
112+
lambda x: exir_ops.backend.tosa.TILE.default(x, [0, 1]),
113+
"TILE multiples must be positive",
114+
id="tile",
115+
),
116+
pytest.param(
117+
lambda x: exir_ops.backend.tosa.TRANSPOSE.default(x, [0, 0]),
118+
"Invalid permutation",
119+
id="transpose",
120+
),
121+
],
122+
)
123+
def test_data_layout_ops_reject_invalid_arguments(op, error_match) -> None:
124+
with TosaLoweringContext(
125+
TosaSpecification.create_from_string("TOSA-1.1+FP")
126+
), FakeTensorMode() as mode:
127+
with pytest.raises(TosaValueError, match=error_match):
128+
op(_fake_tensor(torch.float32, mode))
129+
130+
131+
@pytest.mark.parametrize("op,expected_shape", _DATA_LAYOUT_OPS)
132+
def test_data_layout_ops_reject_int64_without_extension(op, expected_shape) -> None:
133+
with TosaLoweringContext(
134+
TosaSpecification.create_from_string("TOSA-1.1+FP")
135+
), FakeTensorMode() as mode:
136+
with pytest.raises(TosaValueError, match="Unsupported dtype"):
137+
op(_fake_tensor(torch.int64, mode))
138+
139+
140+
def test_int16_data_layout_dtype_support_follows_tosa_spec() -> None:
141+
with TosaLoweringContext(
142+
TosaSpecification.create_from_string("TOSA-1.0+INT")
143+
), FakeTensorMode() as mode:
144+
x = _fake_tensor(torch.int16, mode)
145+
146+
assert exir_ops.backend.tosa.RESHAPE.default(x, [3, 2]).dtype == torch.int16
147+
assert exir_ops.backend.tosa.REVERSE.default(x, axis=0).dtype == torch.int16
148+
assert exir_ops.backend.tosa.TILE.default(x, [1, 1]).dtype == torch.int16
149+
150+
with pytest.raises(TosaValueError, match="Unsupported dtype"):
151+
exir_ops.backend.tosa.CONCAT.default([x, x], axis=0)
152+
153+
with TosaLoweringContext(
154+
TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
155+
), FakeTensorMode() as mode:
156+
x = _fake_tensor(torch.int16, mode)
157+
assert exir_ops.backend.tosa.CONCAT.default([x, x], axis=0).dtype == torch.int16
158+
159+
160+
def test_pad_rejects_wrong_padding_length() -> None:
161+
with TosaLoweringContext(
162+
TosaSpecification.create_from_string("TOSA-1.0+FP")
163+
), FakeTensorMode() as mode:
164+
with pytest.raises(TosaValueError, match="Padding length"):
165+
exir_ops.backend.tosa.PAD.default(
166+
mode.from_tensor(torch.randn((2, 3), dtype=torch.float32)),
167+
[1, 2],
168+
value=0.0,
169+
)
170+
171+
172+
def test_reshape_rejects_size_change():
173+
with TosaLoweringContext(
174+
TosaSpecification.create_from_string("TOSA-1.1+FP")
175+
), FakeTensorMode() as mode:
176+
with pytest.raises(TosaValueError, match="same number of elements"):
177+
exir_ops.backend.tosa.RESHAPE.default(
178+
mode.from_tensor(torch.randn((2, 3), dtype=torch.float32)),
179+
[5],
180+
)

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,19 @@
1010
conv2d,
1111
conv3d,
1212
custom,
13+
data_layout_ops,
1314
depthwise_conv2d,
1415
fft,
1516
gather,
1617
identity,
1718
matmul,
1819
max_pool2d,
1920
max_pool2d_adaptive,
20-
pad,
2121
reduction_ops,
2222
rescale,
2323
resize,
2424
scatter,
2525
shape_ops,
26-
slice,
2726
table,
2827
transpose_conv2d,
2928
unary_elementwise,

0 commit comments

Comments
 (0)