Skip to content

Commit 5cc05a5

Browse files
Nitin Jainmeta-codesync[bot]
authored andcommitted
Add MaxPool1D operator tests (xfail - not yet supported)
Summary: Add unit tests for MaxPool1D operation on ARM backend to demonstrate current lack of support. Tests are marked as xfail and expected to fail. ## Background Research ### PyTorch Native max_pool1d Decomposition PyTorch's native max_pool1d decomposes to max_pool2d via unsqueeze/squeeze: - `max_pool1d_with_indices` → `unsqueeze(-2)` → `max_pool2d_with_indices` → `squeeze(-2)` - Parameter mapping: kernel [k] → [1,k], stride [s] → [1,s], padding [p] → [0,p] ### Quantization Handling - MAX_POOL2D does NOT require zero-point handling (unlike AVG_POOL2D) - Max pooling simply selects the maximum value - quantization params preserved automatically - SharedQuantizationSpec pattern ensures input/output share qparams ### TOSA/Vela Constraints for Ethos-U - U55: Stride 1-3 per dimension, kernel product ≤ 65536, kernel height ≤ 256 - U85: Additional support for larger stride decomposition - TOSA RESHAPE operations are zero-cost in Vela when: - NPU-produced IFM with single consumer - Not graph I/O - Same memory area (tensor equivalence enables memory aliasing) Differential Revision: D91760446
1 parent 7caf9b2 commit 5cc05a5

2 files changed

Lines changed: 165 additions & 0 deletions

File tree

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
Tests for the max_pool1d operation.
10+
11+
In PyTorch, max_pool1d may be decomposed internally into a sequence of
12+
operations (e.g., unsqueeze -> max_pool2d_with_indices -> getitem -> squeeze),
13+
but this test focuses on ensuring that the max_pool1d aten op is correctly
14+
lowered/quantized and delegated to the expected edge dialect op on the
15+
Arm backend (U55/U85).
16+
"""
17+
18+
from typing import Callable, Tuple
19+
20+
import pytest
21+
22+
import torch
23+
24+
from executorch.backends.arm.test import common
25+
26+
from executorch.backends.arm.test.tester.test_pipeline import (
27+
EthosU55PipelineINT,
28+
EthosU85PipelineINT,
29+
TosaPipelineFP,
30+
TosaPipelineINT,
31+
VgfPipeline,
32+
)
33+
34+
input_t1 = Tuple[torch.Tensor]
35+
36+
37+
class MaxPool1d(torch.nn.Module):
38+
def __init__(
39+
self,
40+
kernel_size: int,
41+
stride: int = 1,
42+
padding: int = 0,
43+
):
44+
super().__init__()
45+
self.max_pool_1d = torch.nn.MaxPool1d(
46+
kernel_size=kernel_size,
47+
stride=stride,
48+
padding=padding,
49+
)
50+
51+
def forward(self, x):
52+
return self.max_pool_1d(x)
53+
54+
55+
# Test data suite for single-batch tests (N=1), suitable for all targets
56+
test_data_suite = {
57+
# (test_name, test_data, [kernel_size, stride, padding])
58+
"simple": lambda: (torch.rand(1, 16, 50), [4, 2, 0]),
59+
"with_padding": lambda: (torch.rand(1, 16, 50), [3, 2, 1]),
60+
"stride_1": lambda: (torch.rand(1, 8, 32), [3, 1, 0]),
61+
"larger_kernel": lambda: (torch.rand(1, 4, 64), [8, 4, 0]),
62+
}
63+
64+
# Multi-batch test data (N>1) - not supported on U55 due to N==1 constraint
65+
test_data_suite_multi_batch = {
66+
"multi_batch": lambda: (torch.rand(4, 16, 50), [4, 2, 0]),
67+
}
68+
69+
# Combined suite for targets that support multi-batch (TOSA, U85, VGF)
70+
test_data_suite_all = {**test_data_suite, **test_data_suite_multi_batch}
71+
72+
# After PyTorch decomposition, max_pool1d becomes max_pool2d_with_indices
73+
# After to_edge, becomes max_pool2d_with_indices in edge dialect
74+
aten_op = "torch.ops.aten.max_pool1d.default"
75+
exir_op = "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
76+
77+
78+
@common.parametrize("test_data", test_data_suite_all)
79+
@pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False)
80+
def test_max_pool1d_tosa_FP(test_data: Callable):
81+
"""Test max_pool1d with TOSA FP pipeline."""
82+
test_data, model_params = test_data()
83+
pipeline = TosaPipelineFP[input_t1](
84+
MaxPool1d(*model_params),
85+
(test_data,),
86+
aten_op,
87+
exir_op,
88+
)
89+
pipeline.run()
90+
91+
92+
@common.parametrize("test_data", test_data_suite_all)
93+
@pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False)
94+
def test_max_pool1d_tosa_INT(test_data: Callable):
95+
"""Test max_pool1d with TOSA INT pipeline (quantized)."""
96+
test_data, model_params = test_data()
97+
pipeline = TosaPipelineINT[input_t1](
98+
MaxPool1d(*model_params),
99+
(test_data,),
100+
aten_op,
101+
exir_op,
102+
)
103+
pipeline.run()
104+
105+
106+
@common.parametrize("test_data", test_data_suite)
107+
@common.XfailIfNoCorstone300
108+
@pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False)
109+
def test_max_pool1d_u55_INT(test_data: Callable):
110+
"""Test max_pool1d on Ethos-U55 (quantized)."""
111+
test_data, model_params = test_data()
112+
pipeline = EthosU55PipelineINT[input_t1](
113+
MaxPool1d(*model_params),
114+
(test_data,),
115+
aten_op,
116+
exir_ops=[],
117+
)
118+
pipeline.run()
119+
120+
121+
@common.parametrize("test_data", test_data_suite_all)
122+
@common.XfailIfNoCorstone320
123+
@pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False)
124+
def test_max_pool1d_u85_INT(test_data: Callable):
125+
"""Test max_pool1d on Ethos-U85 (quantized)."""
126+
test_data, model_params = test_data()
127+
pipeline = EthosU85PipelineINT[input_t1](
128+
MaxPool1d(*model_params),
129+
(test_data,),
130+
aten_op,
131+
exir_ops=[],
132+
)
133+
pipeline.run()
134+
135+
136+
# VGF tests
137+
@common.parametrize("test_data", test_data_suite_all)
138+
@common.SkipIfNoModelConverter
139+
def test_max_pool1d_vgf_no_quant(test_data: Callable):
140+
"""Test max_pool1d with VGF pipeline (non-quantized)."""
141+
test_data, model_params = test_data()
142+
pipeline = VgfPipeline[input_t1](
143+
MaxPool1d(*model_params),
144+
(test_data,),
145+
aten_op,
146+
exir_op,
147+
quantize=False,
148+
)
149+
pipeline.run()
150+
151+
152+
@common.parametrize("test_data", test_data_suite_all)
153+
@common.SkipIfNoModelConverter
154+
def test_max_pool1d_vgf_quant(test_data: Callable):
155+
"""Test max_pool1d with VGF pipeline (quantized)."""
156+
test_data, model_params = test_data()
157+
pipeline = VgfPipeline[input_t1](
158+
MaxPool1d(*model_params),
159+
(test_data,),
160+
aten_op,
161+
exir_op,
162+
quantize=True,
163+
)
164+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_cat.py",
2121
"ops/test_conv2d.py",
2222
"ops/test_linear.py",
23+
"ops/test_max_pool1d.py",
2324
"ops/test_mul.py",
2425
"ops/test_permute.py",
2526
"ops/test_rsqrt.py",

0 commit comments

Comments
 (0)