Skip to content

Commit 46190fb

Browse files
Ninja91facebook-github-bot
authored and committed
Add MaxPool1D decomposition pass support (#17022)
Summary: Implement DecomposeMaxPool1dPass to enable MaxPool1D support on ARM backend by decomposing max_pool1d to view_copy → max_pool2d → view_copy.

## Implementation Strategy

### Decomposition Approach (Optimal for TOSA/Vela)

The pass decomposes max_pool1d into max_pool2d via view_copy operations:

1. view_copy: (N, C, L) → (N, C, 1, L) - add height dimension
2. max_pool2d: with adapted params [k]→[1,k], [s]→[1,s], [p]→[0,p]
3. view_copy: (N, C, 1, L_out) → (N, C, L_out) - remove height dimension

### Why This Approach is Optimal

1. **view_copy maps to TOSA RESHAPE** which is zero-cost in Vela:
   - Classified as memory_only_ops (Reshape, Squeeze, ExpandDims, Identity)
   - Bypassed entirely when conditions met (NPU-produced, single consumer)
   - Tensor equivalence enables memory aliasing (same address)
2. **TFA Pipeline Placement (before quantization)**:
   - view_copy is in _one_to_one_shared_input_qspec (line 407)
   - max_pool2d is in _one_to_one_shared_input_or_input_act_qspec (line 455)
   - Both get proper SharedQuantizationSpec from annotator automatically
3. **Quantization Handling**:
   - Clear qparams on intermediate view_copy ops (let annotator fill them)
   - Preserve original meta on max_pool2d for proper tracing
   - MAX_POOL2D doesn't need zero-point handling (unlike AVG_POOL2D)

### TOSA/Vela Constraints Validated

- U55: Stride ≤3 ✓, Kernel ≤256x256 ✓
- U85: Extended stride support via accumulator save/restore
- Dilation: Handled by separate DecomposeMaxPool2dPass if needed

Differential Revision: D91760459
1 parent e7ef74e commit 46190fb

4 files changed

Lines changed: 104 additions & 1 deletion

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .decompose_logit_pass import DecomposeLogitPass # noqa
6666
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
6767
from .decompose_matmul import DecomposeMatmulPass # noqa
68+
from .decompose_max_pool1d_pass import DecomposeMaxPool1dPass # noqa
6869
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
6970
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
7071
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
DecomposeLogitPass,
6767
DecomposeMaskedFillPass,
6868
DecomposeMatmulPass,
69+
DecomposeMaxPool1dPass,
6970
DecomposeMaxPool2dPass,
7071
DecomposeMeanDimPass,
7172
DecomposeNotEqualPass,
@@ -436,6 +437,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
436437
DecomposeLinalgVectorNormPass(tfa_pass=True),
437438
DecomposeSqrtPass(tfa_pass=True),
438439
DecomposeAvgPool2dPass(tfa_pass=True),
440+
DecomposeMaxPool1dPass(tfa_pass=True),
439441
DecomposeSoftmaxUnstablePass(tfa_pass=True),
440442
DecomposeSoftmaxPass(tfa_pass=True),
441443
ConvertMinMaxPass(tfa_pass=True),
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import List, Optional, Set, Type, Union
9+
10+
import torch
11+
from executorch.backends.arm._passes.arm_pass import ArmPass
12+
from executorch.exir.pass_base import ExportPass
13+
14+
15+
def _normalize_to_list(
16+
value: Optional[Union[int, List[int], tuple]],
17+
default: Optional[List[int]] = None,
18+
) -> List[int]:
19+
"""Normalize parameter to list: handle None, int, tuple, list."""
20+
if value is None:
21+
if default is None:
22+
raise ValueError("Value cannot be None without a default")
23+
return default
24+
if isinstance(value, int):
25+
return [value]
26+
return list(value)
27+
28+
29+
class DecomposeMaxPool1dPass(ArmPass):
    """
    Decomposes max_pool1d into max_pool2d via unsqueeze/squeeze operations.

    This pass runs in transform_for_annotation (TFA) pipeline before quantization,
    ensuring proper quantization annotation for the decomposed ops.

    Transformation:
        max_pool1d(x, kernel, stride, padding, dilation, ceil_mode)
        → unsqueeze_copy(x, dim=2)           # (N,C,L) → (N,C,1,L)
        → max_pool2d(..., [1,k], [1,s], [0,p], [1,d], ceil_mode)
        → squeeze_copy(..., dims=[2])        # (N,C,1,L') → (N,C,L')
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite a ``max_pool1d`` call as unsqueeze → max_pool2d → squeeze.

        Any other operator, or a node for which ``allowed_to_transform(meta)``
        is false, is forwarded unchanged to the parent implementation.

        Args:
            op: Operator being visited.
            args: Positional arguments of the call, read as ``(x, kernel_size,
                stride, padding, dilation, ceil_mode)``; everything after
                ``kernel_size`` is optional.
            kwargs: Keyword arguments of the call (only forwarded on the
                pass-through path).
            meta: Node metadata, reused for each emitted operator.

        Returns:
            The result of the final ``squeeze_copy`` call, or of the
            pass-through call when no rewrite happens.
        """
        if op != torch.ops.aten.max_pool1d.default or not self.allowed_to_transform(
            meta
        ):
            return super().call_operator(op, args, kwargs, meta)

        # Extract and normalize arguments
        x = args[0]
        kernel_size = _normalize_to_list(args[1])
        stride = _normalize_to_list(
            args[2] if len(args) > 2 else None,
            default=kernel_size,  # stride defaults to kernel_size
        )
        if not stride:
            # The aten schema spells the default stride as an empty list, so
            # an explicit [] must also mean "use kernel_size". Without this,
            # [1] + [] would hand max_pool2d a one-element stride of [1].
            stride = list(kernel_size)
        padding = _normalize_to_list(args[3] if len(args) > 3 else 0)
        dilation = _normalize_to_list(args[4] if len(args) > 4 else 1)
        ceil_mode = args[5] if len(args) > 5 else False

        # Step 1: Unsqueeze input from 3D to 4D at dim=2
        # (N, C, L) → (N, C, 1, L)
        x_4d = super().call_operator(
            torch.ops.aten.unsqueeze_copy.default,
            (x, 2),
            {},
            meta,
            updated=True,
        )

        # Step 2: Call max_pool2d with 2D parameters
        # kernel: [k] → [1, k], stride: [s] → [1, s]
        # padding: [p] → [0, p], dilation: [d] → [1, d]
        pooled = super().call_operator(
            torch.ops.aten.max_pool2d.default,
            (
                x_4d,
                [1] + kernel_size,
                [1] + stride,
                [0] + padding,
                [1] + dilation,
                ceil_mode,
            ),
            {},
            meta,
            updated=True,
        )

        # Step 3: Squeeze output back to 3D at dims=[2]
        # (N, C, 1, L') → (N, C, L')
        output = super().call_operator(
            torch.ops.aten.squeeze_copy.dims,
            (pooled, [2]),
            {},
            meta,
            updated=True,
        )

        return output

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def _match_pattern(
407407
torch.ops.aten.squeeze.default,
408408
torch.ops.aten.squeeze_copy.default,
409409
torch.ops.aten.squeeze_copy.dim,
410-
torch.ops.aten.squeeze_.dim,
410+
torch.ops.aten.squeeze_copy.dims,
411411
torch.ops.aten.squeeze.dim,
412412
torch.ops.aten.squeeze.dims,
413413
torch.ops.aten.unbind.int,

0 commit comments

Comments
 (0)