Skip to content

Commit ad1f626

Browse files
Ninja91facebook-github-bot
authored andcommitted
Add MaxPool1D decomposition pass support (#17022)
Summary: Implement DecomposeMaxPool1dPass to enable MaxPool1D support on ARM backend by decomposing max_pool1d into unsqueeze_copy → max_pool2d → squeeze_copy. ## Implementation Strategy ### Decomposition Approach (Optimal for TOSA/Vela) The pass decomposes max_pool1d into max_pool2d via unsqueeze_copy/squeeze_copy operations: 1. unsqueeze_copy(dim=2): (N, C, L) → (N, C, 1, L) - add height dimension 2. max_pool2d: with adapted params [k]→[1,k], [s]→[1,s], [p]→[0,p], [d]→[1,d] 3. squeeze_copy(dims=[2]): (N, C, 1, L_out) → (N, C, L_out) - remove height dimension ### Why This Approach is Optimal 1. **unsqueeze_copy and squeeze_copy map to TOSA RESHAPE** which is zero-cost in Vela: - Classified as memory_only_ops (Reshape, Squeeze, ExpandDims, Identity) - Bypassed entirely when conditions met (NPU-produced, single consumer) - Tensor equivalence enables memory aliasing (same address) 2. **TFA Pipeline Placement (before quantization)**: - unsqueeze_copy.default is in _one_to_one_shared_input_qspec - squeeze_copy.dims is added to _one_to_one_shared_input_qspec - max_pool2d is in _one_to_one_shared_input_or_input_act_qspec - All get proper SharedQuantizationSpec from the annotator automatically 3. **Quantization Handling**: - Clear qparams on intermediate unsqueeze_copy and squeeze_copy ops (let annotator fill them) - Preserve original meta on max_pool2d for proper tracing - MAX_POOL2D doesn't need zero-point handling (unlike AVG_POOL2D) ### TOSA/Vela Constraints Validated - U55: Stride ≤3 ✓, Kernel ≤256x256 ✓ - U85: Extended stride support via accumulator save/restore - Dilation: Handled by separate DecomposeMaxPool2dPass if needed Reviewed By: 3l1 Differential Revision: D91760459
1 parent 10c8958 commit ad1f626

5 files changed

Lines changed: 117 additions & 16 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
from .decompose_lstm_pass import DecomposeLstmPass # noqa
7171
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
7272
from .decompose_matmul import DecomposeMatmulPass # noqa
73+
from .decompose_max_pool1d_pass import DecomposeMaxPool1dPass # noqa
7374
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
7475
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
7576
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
DecomposeLstmPass,
7373
DecomposeMaskedFillPass,
7474
DecomposeMatmulPass,
75+
DecomposeMaxPool1dPass,
7576
DecomposeMaxPool2dPass,
7677
DecomposeMeanDimPass,
7778
DecomposeNotEqualPass,
@@ -506,6 +507,7 @@ def _tosa_pipeline(
506507
UnsqueezeBeforeRepeatPass(),
507508
DecomposeCumsumPass(exported_program),
508509
DecomposeAsStridedCopyPass(),
510+
DecomposeMaxPool1dPass(),
509511
DecomposeMaxPool2dPass(),
510512
SizeAdjustInputPass(),
511513
RewriteAvgPool2dPass(),
@@ -638,6 +640,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
638640
DecomposeDivPass(tfa_pass=True),
639641
DecomposeLinalgVectorNormPass(tfa_pass=True),
640642
DecomposeSqrtPass(tfa_pass=True),
643+
DecomposeMaxPool1dPass(tfa_pass=True),
641644
DecomposeSoftmaxPass(
642645
tfa_pass=True,
643646
),
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import List, Optional, Set, Type, Union
9+
10+
import torch
11+
from executorch.backends.arm._passes.arm_pass import ArmPass
12+
from executorch.exir.pass_base import ExportPass
13+
14+
15+
def _normalize_to_list(
16+
value: Optional[Union[int, List[int], tuple]],
17+
default: Optional[List[int]] = None,
18+
) -> List[int]:
19+
"""Normalize parameter to list: handle None, int, tuple, list."""
20+
if value is None:
21+
if default is None:
22+
raise ValueError("Value cannot be None without a default")
23+
return default
24+
if isinstance(value, int):
25+
return [value]
26+
return list(value)
27+
28+
29+
class DecomposeMaxPool1dPass(ArmPass):
30+
"""Decomposes max_pool1d into max_pool2d via unsqueeze_copy/squeeze_copy
31+
operations.
32+
33+
This pass runs in transform_for_annotation (TFA) pipeline before quantization,
34+
ensuring proper quantization annotation for the decomposed ops.
35+
36+
Transformation:
37+
max_pool1d(x, kernel, stride, padding, dilation, ceil_mode)
38+
→ unsqueeze_copy(x, dim=2) # (N,C,L) → (N,C,1,L)
39+
→ max_pool2d(..., [1,k], [1,s], [0,p], [1,d], ceil_mode)
40+
→ squeeze_copy(..., dims=[2]) # (N,C,1,L') → (N,C,L')
41+
"""
42+
43+
_passes_required_after: Set[Type[ExportPass]] = set()
44+
45+
def call_operator(self, op, args, kwargs, meta):
46+
if op != torch.ops.aten.max_pool1d.default or not self.allowed_to_transform(
47+
meta
48+
):
49+
return super().call_operator(op, args, kwargs, meta)
50+
51+
# Extract and normalize arguments
52+
x = args[0]
53+
kernel_size = _normalize_to_list(args[1])
54+
stride = _normalize_to_list(
55+
args[2] if len(args) > 2 else None,
56+
default=kernel_size, # stride defaults to kernel_size
57+
)
58+
padding = _normalize_to_list(args[3] if len(args) > 3 else 0)
59+
dilation = _normalize_to_list(args[4] if len(args) > 4 else 1)
60+
ceil_mode = args[5] if len(args) > 5 else False
61+
62+
# Step 1: Unsqueeze input from 3D to 4D at dim=2
63+
# (N, C, L) → (N, C, 1, L)
64+
unsqueeze_meta = meta.copy()
65+
unsqueeze_meta.data["input_qparams"] = {}
66+
unsqueeze_meta.data["output_qparams"] = {}
67+
x_4d = super().call_operator(
68+
torch.ops.aten.unsqueeze_copy.default,
69+
(x, 2),
70+
{},
71+
unsqueeze_meta,
72+
updated=True,
73+
)
74+
75+
# Step 2: Call max_pool2d with 2D parameters
76+
# kernel: [k] → [1, k], stride: [s] → [1, s]
77+
# padding: [p] → [0, p], dilation: [d] → [1, d]
78+
pooled = super().call_operator(
79+
torch.ops.aten.max_pool2d.default,
80+
(
81+
x_4d,
82+
[1] + kernel_size,
83+
[1] + stride,
84+
[0] + padding,
85+
[1] + dilation,
86+
ceil_mode,
87+
),
88+
{},
89+
meta,
90+
updated=True,
91+
)
92+
93+
# Step 3: Squeeze output back to 3D at dims=[2]
94+
# (N, C, 1, L') → (N, C, L')
95+
squeeze_meta = meta.copy()
96+
squeeze_meta.data["input_qparams"] = {}
97+
squeeze_meta.data["output_qparams"] = {}
98+
output = super().call_operator(
99+
torch.ops.aten.squeeze_copy.dims,
100+
(pooled, [2]),
101+
{},
102+
squeeze_meta,
103+
updated=True,
104+
)
105+
106+
return output

backends/arm/quantizer/quantization_annotator.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55
"""Provide quantization annotation logic for Arm backends.
66
7-
This module computes per-node quantization properties and applies input/output
8-
annotations to FX graphs using TorchAO qspecs.
9-
7+
This module computes per-node quantization properties and applies
8+
input/output annotations to FX graphs using TorchAO qspecs.
109
"""
1110

1211
import functools
@@ -72,7 +71,6 @@ class _OpQuantProperties:
7271
indexed by argument positions.
7372
quant_output (Optional[_QuantProperty]): Quantization spec for the
7473
node's output when applicable.
75-
7674
"""
7775

7876
def __init__(self):
@@ -93,7 +91,6 @@ def _as_list(x):
9391
9492
Returns:
9593
list: ``x`` if already a list; otherwise ``[x]``.
96-
9794
"""
9895
if isinstance(x, (list, tuple)):
9996
return x
@@ -206,7 +203,6 @@ def _is_ok_for_quantization(
206203
207204
Returns:
208205
bool: `True` if the node can be quantized, otherwise `False`.
209-
210206
"""
211207
# Check output
212208
if quant_properties.quant_output is not None:
@@ -266,7 +262,6 @@ def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str:
266262
267263
Returns:
268264
Any: Resolved attribute on the module.
269-
270265
"""
271266
targets = target_str.split(".")
272267
for target in targets[:-1]:
@@ -279,7 +274,6 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
279274
280275
Large scalars are skipped because ``torch.histc`` supports values only up
281276
to a certain upper bound.
282-
283277
"""
284278
HISTC_UPPER_BOUND = 3.4028235e15
285279
if node.op == "get_attr" and isinstance(node.target, str):
@@ -297,7 +291,8 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
297291

298292

299293
def _is_non_float_tensor(node: Node) -> bool:
300-
"""Check if the output of a node has a data type other than `torch.float32`.
294+
"""Check if the output of a node has a data type other than
295+
`torch.float32`.
301296
302297
If the output is not `torch.float32`, quantization cannot be performed, as
303298
observers only work with floating-point tensors.
@@ -314,7 +309,6 @@ def _is_non_float_tensor(node: Node) -> bool:
314309
`torch.float32` as its data type.
315310
- If node.meta["val"] is missing or is not an instance of `FakeTensor`,
316311
the function returns True.
317-
318312
"""
319313
if "val" in node.meta and isinstance(node.meta["val"], Sequence):
320314
return any(
@@ -342,7 +336,6 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
342336
Raises:
343337
RuntimeError: If the node is already annotated.
344338
TypeError: If an input argument is not a ``Node`` instance.
345-
346339
"""
347340
if is_annotated(node):
348341
raise RuntimeError(
@@ -379,7 +372,6 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
379372
RuntimeError: If the node is already annotated.
380373
ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
381374
``index`` is not zero.
382-
383375
"""
384376
if is_annotated(node):
385377
raise RuntimeError(
@@ -408,7 +400,6 @@ def _match_pattern(
408400
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
409401
to pass the filter. Each pattern element is an iterable of disjunctive
410402
node targets.
411-
412403
"""
413404
if len(pattern) < 1:
414405
raise ValueError("No pattern provided")
@@ -517,6 +508,9 @@ def _match_pattern(
517508
torch.ops.aten.squeeze_copy.default,
518509
torch.ops.aten.squeeze_copy.dim,
519510
torch.ops.aten.squeeze_.dim,
511+
# DecomposeMaxPool1dPass emits squeeze_copy.dims as a view-like intermediate;
512+
# include here so it receives SharedQuantizationSpec from its input.
513+
torch.ops.aten.squeeze_copy.dims,
520514
torch.ops.aten.squeeze.dim,
521515
torch.ops.aten.squeeze.dims,
522516
torch.ops.aten.unbind.int,
@@ -612,7 +606,6 @@ def get_quant_properties( # noqa: C901
612606
Returns:
613607
_OpQuantProperties | None: Properties to apply, or ``None`` if the
614608
node is unsupported or not suitable for quantization.
615-
616609
"""
617610
if node.target == torch.ops.aten.conv_transpose2d.input:
618611
weight_qspec = _adjust_weight_qspec_for_conv_transpose(
@@ -950,7 +943,6 @@ def annotate_graph( # type: ignore[return]
950943
951944
Returns:
952945
Optional[List[List[Node]]]: Reserved for future use; currently None.
953-
954946
"""
955947
for node in gm.graph.nodes:
956948
if node.op != "call_function":

backends/arm/test/ops/test_max_pool1d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def test_max_pool2d_tosa_FP_decomposed(test_data: Callable):
8888

8989

9090
@common.parametrize("test_data", test_data_suite_all)
91-
@pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False)
9291
def test_max_pool2d_tosa_INT_decomposed(test_data: Callable):
9392
"""Test max_pool1d with TOSA INT pipeline (quantized)."""
9493
test_data, model_params = test_data()

0 commit comments

Comments
 (0)