Skip to content

Commit d001896

Browse files
Arm backend: Unmark Transformer as xfail (pytorch#16930)
Transformer model was previously marked as expected failure due to numerical differences between reference and lowered model. At closer inspection, this difference appear to be within an acceptable range. cc @freddan80 @per @zingo @digantdesai Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent a093fe4 commit d001896

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

backends/arm/test/models/test_nn_modules.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22

33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -46,51 +46,61 @@ def forward(self, *args, **kwargs):
4646
example_input = torch.rand(1, 6, 16, 16)
4747

4848
module_tests = [
49+
# (module, test_tuple, kwargs)
4950
(
5051
make_module_wrapper(
5152
"EmbeddingModule",
5253
lambda: torch.nn.Embedding(10, 10),
5354
),
5455
(torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]),),
56+
{},
5557
),
5658
(
5759
make_module_wrapper("LeakyReLUModule", torch.nn.LeakyReLU),
5860
(example_input,),
61+
{},
5962
),
6063
(
6164
make_module_wrapper("BatchNorm1dModule", lambda: torch.nn.BatchNorm1d(16)),
6265
(torch.rand(6, 16, 16),),
66+
{},
6367
),
6468
(
6569
make_module_wrapper(
6670
"AdaptiveAvgPool2dModule",
6771
lambda: torch.nn.AdaptiveAvgPool2d((12, 12)),
6872
),
6973
(example_input,),
74+
{},
7075
),
7176
(
7277
make_module_wrapper(
7378
"ConvTranspose2dModule", lambda: torch.nn.ConvTranspose2d(6, 3, 2)
7479
),
7580
(example_input,),
81+
{},
7682
),
7783
(
7884
make_module_wrapper("GRUModule", lambda: torch.nn.GRU(10, 20, 2)),
7985
(torch.randn(5, 3, 10), torch.randn(2, 3, 20)),
86+
{},
8087
),
8188
(
8289
make_module_wrapper("GroupNormModule", lambda: torch.nn.GroupNorm(2, 6)),
8390
(example_input,),
91+
{},
8492
),
8593
(
8694
make_module_wrapper(
8795
"InstanceNorm2dModule", lambda: torch.nn.InstanceNorm2d(16)
8896
),
8997
(example_input,),
98+
{},
9099
),
91100
(
92101
make_module_wrapper("PReLUModule", torch.nn.PReLU),
93102
(example_input,),
103+
{},
94104
),
95105
(
96106
make_module_wrapper(
@@ -104,6 +114,7 @@ def forward(self, *args, **kwargs):
104114
),
105115
),
106116
(torch.rand((10, 32, 64)), torch.rand((20, 32, 64))),
117+
{"atol": 0.1},
107118
),
108119
]
109120

@@ -117,7 +128,7 @@ def forward(self, *args, **kwargs):
117128
test_parameters,
118129
)
119130
def test_nn_modules_tosa_FP(test_data):
120-
module, inputs = test_data
131+
module, inputs, _ = test_data
121132
pipeline = TosaPipelineFP[input_t](
122133
module, inputs, "", use_to_edge_transform_and_lower=True
123134
)
@@ -136,15 +147,10 @@ def test_nn_modules_tosa_FP(test_data):
136147
@parametrize(
137148
"test_data",
138149
test_parameters,
139-
xfails={
140-
"TransformerModule": "AssertionError: Output 0 does not match reference output.",
141-
},
142150
)
143151
def test_nn_modules_tosa_INT(test_data):
144-
module, inputs = test_data
145-
pipeline = TosaPipelineINT[input_t](
146-
module, inputs, "", use_to_edge_transform_and_lower=True
147-
)
152+
module, inputs, kwargs = test_data
153+
pipeline = TosaPipelineINT[input_t](module, inputs, "", **kwargs)
148154
pipeline.pop_stage("check.aten")
149155
pipeline.pop_stage("check_count.exir")
150156
if pipeline.has_stage("check.quant_nodes"):

0 commit comments

Comments
 (0)