1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
@@ -46,51 +46,61 @@ def forward(self, *args, **kwargs):
4646example_input = torch .rand (1 , 6 , 16 , 16 )
4747
4848module_tests = [
49+ # (module, test_tuple, kwargs)
4950 (
5051 make_module_wrapper (
5152 "EmbeddingModule" ,
5253 lambda : torch .nn .Embedding (10 , 10 ),
5354 ),
5455 (torch .LongTensor ([[1 , 2 , 4 , 5 ], [4 , 3 , 2 , 9 ]]),),
56+ {},
5557 ),
5658 (
5759 make_module_wrapper ("LeakyReLUModule" , torch .nn .LeakyReLU ),
5860 (example_input ,),
61+ {},
5962 ),
6063 (
6164 make_module_wrapper ("BatchNorm1dModule" , lambda : torch .nn .BatchNorm1d (16 )),
6265 (torch .rand (6 , 16 , 16 ),),
66+ {},
6367 ),
6468 (
6569 make_module_wrapper (
6670 "AdaptiveAvgPool2dModule" ,
6771 lambda : torch .nn .AdaptiveAvgPool2d ((12 , 12 )),
6872 ),
6973 (example_input ,),
74+ {},
7075 ),
7176 (
7277 make_module_wrapper (
7378 "ConvTranspose2dModule" , lambda : torch .nn .ConvTranspose2d (6 , 3 , 2 )
7479 ),
7580 (example_input ,),
81+ {},
7682 ),
7783 (
7884 make_module_wrapper ("GRUModule" , lambda : torch .nn .GRU (10 , 20 , 2 )),
7985 (torch .randn (5 , 3 , 10 ), torch .randn (2 , 3 , 20 )),
86+ {},
8087 ),
8188 (
8289 make_module_wrapper ("GroupNormModule" , lambda : torch .nn .GroupNorm (2 , 6 )),
8390 (example_input ,),
91+ {},
8492 ),
8593 (
8694 make_module_wrapper (
8795 "InstanceNorm2dModule" , lambda : torch .nn .InstanceNorm2d (16 )
8896 ),
8997 (example_input ,),
98+ {},
9099 ),
91100 (
92101 make_module_wrapper ("PReLUModule" , torch .nn .PReLU ),
93102 (example_input ,),
103+ {},
94104 ),
95105 (
96106 make_module_wrapper (
@@ -104,6 +114,7 @@ def forward(self, *args, **kwargs):
104114 ),
105115 ),
106116 (torch .rand ((10 , 32 , 64 )), torch .rand ((20 , 32 , 64 ))),
117+ {"atol" : 0.1 },
107118 ),
108119]
109120
@@ -117,7 +128,7 @@ def forward(self, *args, **kwargs):
117128 test_parameters ,
118129)
119130def test_nn_modules_tosa_FP (test_data ):
120- module , inputs = test_data
131+ module , inputs , _ = test_data
121132 pipeline = TosaPipelineFP [input_t ](
122133 module , inputs , "" , use_to_edge_transform_and_lower = True
123134 )
@@ -136,15 +147,10 @@ def test_nn_modules_tosa_FP(test_data):
136147@parametrize (
137148 "test_data" ,
138149 test_parameters ,
139- xfails = {
140- "TransformerModule" : "AssertionError: Output 0 does not match reference output." ,
141- },
142150)
143151def test_nn_modules_tosa_INT (test_data ):
144- module , inputs = test_data
145- pipeline = TosaPipelineINT [input_t ](
146- module , inputs , "" , use_to_edge_transform_and_lower = True
147- )
152+ module , inputs , kwargs = test_data
153+ pipeline = TosaPipelineINT [input_t ](module , inputs , "" , ** kwargs )
148154 pipeline .pop_stage ("check.aten" )
149155 pipeline .pop_stage ("check_count.exir" )
150156 if pipeline .has_stage ("check.quant_nodes" ):
0 commit comments