Skip to content

Commit 8805beb

Browse files
Add a16w8 per-op test for var (#19596)
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.var` on Ethos-U55 and Ethos-U85. ## Changes - Add `test_parameters_ethosu` class attribute to `Var` with 2 test configurations (4D tensors with correction=0 and correction=1) - Switch existing `test_var_dim_u55_INT_no_dim` and `test_var_dim_u85_INT_no_dim` from `Var.test_parameters` to `Var.test_parameters_ethosu` for Ethos-U compatible tensor shapes - Add `test_var_a16w8_u55_INT` using `EthosU55PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True` - Add `test_var_a16w8_u85_INT` using `EthosU85PipelineINT` with same kwargs - Register `ops/test_var.py` in `fbcode/` and `xplat/` `targets.bzl` Differential Revision: D104532362
1 parent d1db6b7 commit 8805beb

2 files changed

Lines changed: 38 additions & 2 deletions

File tree

backends/arm/test/ops/test_var.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class Var(torch.nn.Module):
3232
),
3333
}
3434

35+
test_parameters_ethosu = {
36+
"var_4d_keep_dim_0_correction": lambda: (torch.randn(1, 50, 10, 20), True, 0),
37+
"var_4d_keep_dim_1_correction": lambda: (torch.randn(1, 30, 15, 20), True, 1),
38+
}
39+
3540
def __init__(self, keepdim: bool = True, correction: int = 0):
3641
super().__init__()
3742
self.keepdim = keepdim
@@ -170,7 +175,7 @@ def test_var_dim_tosa_INT_no_dim(test_data: Tuple):
170175
pipeline.run()
171176

172177

173-
@common.parametrize("test_data", Var.test_parameters)
178+
@common.parametrize("test_data", Var.test_parameters_ethosu)
174179
@common.XfailIfNoCorstone300
175180
def test_var_dim_u55_INT_no_dim(test_data: Tuple):
176181
test_data, keepdim, correction = test_data()
@@ -183,7 +188,7 @@ def test_var_dim_u55_INT_no_dim(test_data: Tuple):
183188
pipeline.run()
184189

185190

186-
@common.parametrize("test_data", Var.test_parameters)
191+
@common.parametrize("test_data", Var.test_parameters_ethosu)
187192
@common.XfailIfNoCorstone320
188193
def test_var_dim_u85_INT_no_dim(test_data: Tuple):
189194
test_data, keepdim, correction = test_data()
@@ -224,6 +229,36 @@ def test_var_dim_vgf_quant_no_dim(test_data: Tuple):
224229
pipeline.run()
225230

226231

232+
@common.parametrize("test_data", Var.test_parameters_ethosu)
233+
@common.XfailIfNoCorstone300
234+
def test_var_a16w8_u55_INT(test_data: Tuple):
235+
test_data, keepdim, correction = test_data()
236+
pipeline = EthosU55PipelineINT[input_t1](
237+
Var(keepdim, correction),
238+
(test_data,),
239+
aten_ops=[],
240+
exir_ops=[],
241+
a16w8_quantization=True,
242+
symmetric_io_quantization=True,
243+
)
244+
pipeline.run()
245+
246+
247+
@common.parametrize("test_data", Var.test_parameters_ethosu)
248+
@common.XfailIfNoCorstone320
249+
def test_var_a16w8_u85_INT(test_data: Tuple):
250+
test_data, keepdim, correction = test_data()
251+
pipeline = EthosU85PipelineINT[input_t1](
252+
Var(keepdim, correction),
253+
(test_data,),
254+
aten_ops=[],
255+
exir_ops=[],
256+
a16w8_quantization=True,
257+
symmetric_io_quantization=True,
258+
)
259+
pipeline.run()
260+
261+
227262
#############
228263
## VarDim ###
229264
#############

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def define_arm_tests():
3939
"ops/test_exp.py",
4040
"ops/test_reciprocal.py",
4141
"ops/test_mean_dim.py",
42+
"ops/test_var.py",
4243
]
4344

4445
# Quantization

0 commit comments

Comments
 (0)