Skip to content

Commit a56c81d

Browse files
Ninja91facebook-github-bot
authored andcommitted
Add MaxPool1D decomposition pass support (#17022)
Summary: Implement DecomposeMaxPool1dPass to enable MaxPool1D support on ARM backend by decomposing max_pool1d to view_copy → max_pool2d → view_copy. ## Implementation Strategy ### Decomposition Approach (Optimal for TOSA/Vela) The pass decomposes max_pool1d into max_pool2d via view_copy operations: 1. view_copy: (N, C, L) → (N, C, 1, L) - add height dimension 2. max_pool2d: with adapted params [k]→[1,k], [s]→[1,s], [p]→[0,p] 3. view_copy: (N, C, 1, L_out) → (N, C, L_out) - remove height dimension ### Why This Approach is Optimal 1. **view_copy maps to TOSA RESHAPE** which is zero-cost in Vela: - Classified as memory_only_ops (Reshape, Squeeze, ExpandDims, Identity) - Bypassed entirely when conditions met (NPU-produced, single consumer) - Tensor equivalence enables memory aliasing (same address) 2. **TFA Pipeline Placement (before quantization)**: - view_copy is in _one_to_one_shared_input_qspec (line 407) - max_pool2d is in _one_to_one_shared_input_or_input_act_qspec (line 455) - Both get proper SharedQuantizationSpec from annotator automatically 3. **Quantization Handling**: - Clear qparams on intermediate view_copy ops (let annotator fill them) - Preserve original meta on max_pool2d for proper tracing - MAX_POOL2D doesn't need zero-point handling (unlike AVG_POOL2D) ### TOSA/Vela Constraints Validated - U55: Stride ≤3 ✓, Kernel ≤256x256 ✓ - U85: Extended stride support via accumulator save/restore - Dilation: Handled by separate DecomposeMaxPool2dPass if needed Reviewed By: 3l1 Differential Revision: D91760459
1 parent 37b675c commit a56c81d

4 files changed

Lines changed: 109 additions & 15 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from .decompose_logit_pass import DecomposeLogitPass # noqa
6767
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
6868
from .decompose_matmul import DecomposeMatmulPass # noqa
69+
from .decompose_max_pool1d_pass import DecomposeMaxPool1dPass # noqa
6970
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
7071
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
7172
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
DecomposeLogitPass,
6969
DecomposeMaskedFillPass,
7070
DecomposeMatmulPass,
71+
DecomposeMaxPool1dPass,
7172
DecomposeMaxPool2dPass,
7273
DecomposeMeanDimPass,
7374
DecomposeNotEqualPass,
@@ -343,6 +344,7 @@ def _tosa_pipeline(
343344
DecomposeCumsumPass(exported_program),
344345
DecomposeAsStridedCopyPass(),
345346
DecomposeMaxPool2dPass(),
347+
DecomposeMaxPool1dPass(),
346348
SizeAdjustInputPass(),
347349
DecomposeSelectPass(),
348350
ConvertSqueezesToViewPass(),
@@ -447,6 +449,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
447449
DecomposeLinalgVectorNormPass(tfa_pass=True),
448450
DecomposeSqrtPass(tfa_pass=True),
449451
DecomposeAvgPool2dPass(tfa_pass=True),
452+
DecomposeMaxPool1dPass(tfa_pass=True),
450453
DecomposeSoftmaxUnstablePass(tfa_pass=True),
451454
DecomposeSoftmaxPass(tfa_pass=True),
452455
ConvertMinMaxPass(tfa_pass=True),
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import List, Optional, Set, Type, Union
9+
10+
import torch
11+
from executorch.backends.arm._passes.arm_pass import ArmPass
12+
from executorch.exir.pass_base import ExportPass
13+
14+
15+
def _normalize_to_list(
16+
value: Optional[Union[int, List[int], tuple]],
17+
default: Optional[List[int]] = None,
18+
) -> List[int]:
19+
"""Normalize parameter to list: handle None, int, tuple, list."""
20+
if value is None:
21+
if default is None:
22+
raise ValueError("Value cannot be None without a default")
23+
return default
24+
if isinstance(value, int):
25+
return [value]
26+
return list(value)
27+
28+
29+
class DecomposeMaxPool1dPass(ArmPass):
30+
"""
31+
Decomposes max_pool1d into max_pool2d via unsqueeze_copy/squeeze_copy operations.
32+
33+
This pass runs in transform_for_annotation (TFA) pipeline before quantization,
34+
ensuring proper quantization annotation for the decomposed ops.
35+
36+
Transformation:
37+
max_pool1d(x, kernel, stride, padding, dilation, ceil_mode)
38+
→ unsqueeze_copy(x, dim=2) # (N,C,L) → (N,C,1,L)
39+
→ max_pool2d(..., [1,k], [1,s], [0,p], [1,d], ceil_mode)
40+
→ squeeze_copy(..., dims=[2]) # (N,C,1,L') → (N,C,L')
41+
"""
42+
43+
_passes_required_after: Set[Type[ExportPass]] = set()
44+
45+
def call_operator(self, op, args, kwargs, meta):
46+
if op != torch.ops.aten.max_pool1d.default or not self.allowed_to_transform(
47+
meta
48+
):
49+
return super().call_operator(op, args, kwargs, meta)
50+
51+
# Extract and normalize arguments
52+
x = args[0]
53+
kernel_size = _normalize_to_list(args[1])
54+
stride = _normalize_to_list(
55+
args[2] if len(args) > 2 else None,
56+
default=kernel_size, # stride defaults to kernel_size
57+
)
58+
padding = _normalize_to_list(args[3] if len(args) > 3 else 0)
59+
dilation = _normalize_to_list(args[4] if len(args) > 4 else 1)
60+
ceil_mode = args[5] if len(args) > 5 else False
61+
62+
# Step 1: Unsqueeze input from 3D to 4D at dim=2
63+
# (N, C, L) → (N, C, 1, L)
64+
x_4d = super().call_operator(
65+
torch.ops.aten.unsqueeze_copy.default,
66+
(x, 2),
67+
{},
68+
meta,
69+
updated=True,
70+
)
71+
72+
# Step 2: Call max_pool2d with 2D parameters
73+
# kernel: [k] → [1, k], stride: [s] → [1, s]
74+
# padding: [p] → [0, p], dilation: [d] → [1, d]
75+
pooled = super().call_operator(
76+
torch.ops.aten.max_pool2d.default,
77+
(
78+
x_4d,
79+
[1] + kernel_size,
80+
[1] + stride,
81+
[0] + padding,
82+
[1] + dilation,
83+
ceil_mode,
84+
),
85+
{},
86+
meta,
87+
updated=True,
88+
)
89+
90+
# Step 3: Squeeze output back to 3D at dims=[2]
91+
# (N, C, 1, L') → (N, C, L')
92+
output = super().call_operator(
93+
torch.ops.aten.squeeze_copy.dims,
94+
(pooled, [2]),
95+
{},
96+
meta,
97+
updated=True,
98+
)
99+
100+
return output

backends/arm/quantizer/quantization_annotator.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55
"""Provide quantization annotation logic for Arm backends.
66
7-
This module computes per-node quantization properties and applies input/output
8-
annotations to FX graphs using TorchAO qspecs.
9-
7+
This module computes per-node quantization properties and applies
8+
input/output annotations to FX graphs using TorchAO qspecs.
109
"""
1110

1211
import logging
@@ -57,7 +56,6 @@ class _OpQuantProperties:
5756
indexed by argument positions.
5857
quant_output (Optional[_QuantProperty]): Quantization spec for the
5958
node's output when applicable.
60-
6159
"""
6260

6361
def __init__(self):
@@ -73,7 +71,6 @@ def _as_list(x):
7371
7472
Returns:
7573
list: ``x`` if already a list; otherwise ``[x]``.
76-
7774
"""
7875
if isinstance(x, (list, tuple)):
7976
return x
@@ -122,7 +119,6 @@ def _is_ok_for_quantization(
122119
123120
Returns:
124121
bool: `True` if the node can be quantized, otherwise `False`.
125-
126122
"""
127123
# Check output
128124
if quant_properties.quant_output is not None:
@@ -182,7 +178,6 @@ def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str:
182178
183179
Returns:
184180
Any: Resolved attribute on the module.
185-
186181
"""
187182
targets = target_str.split(".")
188183
for target in targets[:-1]:
@@ -195,7 +190,6 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
195190
196191
Large scalars are skipped because ``torch.histc`` supports values only up
197192
to a certain upper bound.
198-
199193
"""
200194
HISTC_UPPER_BOUND = 3.4028235e15
201195
if node.op == "get_attr" and isinstance(node.target, str):
@@ -213,7 +207,8 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
213207

214208

215209
def _is_non_float_tensor(node: Node) -> bool:
216-
"""Check if the output of a node has a data type other than `torch.float32`.
210+
"""Check if the output of a node has a data type other than
211+
`torch.float32`.
217212
218213
If the output is not `torch.float32`, quantization cannot be performed, as
219214
observers only work with floating-point tensors.
@@ -230,7 +225,6 @@ def _is_non_float_tensor(node: Node) -> bool:
230225
`torch.float32` as its data type.
231226
- If node.meta["val"] is missing or is not an instance of `FakeTensor`,
232227
the function returns True.
233-
234228
"""
235229
if "val" in node.meta and isinstance(node.meta["val"], Sequence):
236230
return any(
@@ -258,7 +252,6 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
258252
Raises:
259253
RuntimeError: If the node is already annotated.
260254
TypeError: If an input argument is not a ``Node`` instance.
261-
262255
"""
263256
if is_annotated(node):
264257
raise RuntimeError(
@@ -295,7 +288,6 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
295288
RuntimeError: If the node is already annotated.
296289
ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
297290
``index`` is not zero.
298-
299291
"""
300292
if is_annotated(node):
301293
raise RuntimeError(
@@ -322,7 +314,6 @@ def _match_pattern(
322314
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
323315
to pass the filter. Each pattern element is a list of disjunctive node
324316
targets.
325-
326317
"""
327318
if len(pattern) < 1:
328319
raise ValueError("No pattern provided")
@@ -408,6 +399,7 @@ def _match_pattern(
408399
torch.ops.aten.squeeze_copy.default,
409400
torch.ops.aten.squeeze_copy.dim,
410401
torch.ops.aten.squeeze_.dim,
402+
torch.ops.aten.squeeze_copy.dims,
411403
torch.ops.aten.squeeze.dim,
412404
torch.ops.aten.squeeze.dims,
413405
torch.ops.aten.unbind.int,
@@ -503,7 +495,6 @@ def get_quant_properties( # noqa: C901
503495
Returns:
504496
_OpQuantProperties | None: Properties to apply, or ``None`` if the
505497
node is unsupported or not suitable for quantization.
506-
507498
"""
508499
if node.target == torch.ops.aten.conv_transpose2d.input:
509500
weight_qspec = _adjust_weight_qspec_for_conv_transpose(
@@ -820,7 +811,6 @@ def annotate_graph( # type: ignore[return]
820811
821812
Returns:
822813
Optional[List[List[Node]]]: Reserved for future use; currently None.
823-
824814
"""
825815
for node in gm.graph.nodes:
826816
if node.op != "call_function":

0 commit comments

Comments
 (0)