Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from .decompose_lstm_pass import DecomposeLstmPass # noqa
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
from .decompose_matmul import DecomposeMatmulPass # noqa
from .decompose_max_pool1d_pass import DecomposeMaxPool1dPass # noqa
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
DecomposeLstmPass,
DecomposeMaskedFillPass,
DecomposeMatmulPass,
DecomposeMaxPool1dPass,
DecomposeMaxPool2dPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
Expand Down Expand Up @@ -506,6 +507,7 @@ def _tosa_pipeline(
UnsqueezeBeforeRepeatPass(),
DecomposeCumsumPass(exported_program),
DecomposeAsStridedCopyPass(),
DecomposeMaxPool1dPass(),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with this one. We shouldn't see any max_pool1d ops at this stage; the pass won't match any exir targets and will emit torch.ops rather than exir_ops.

DecomposeMaxPool2dPass(),
SizeAdjustInputPass(),
RewriteAvgPool2dPass(),
Expand Down Expand Up @@ -638,6 +640,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
DecomposeDivPass(tfa_pass=True),
DecomposeLinalgVectorNormPass(tfa_pass=True),
DecomposeSqrtPass(tfa_pass=True),
DecomposeMaxPool1dPass(tfa_pass=True),
DecomposeSoftmaxPass(
tfa_pass=True,
),
Expand Down
106 changes: 106 additions & 0 deletions backends/arm/_passes/decompose_max_pool1d_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Set, Type, Union

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.pass_base import ExportPass


def _normalize_to_list(
value: Optional[Union[int, List[int], tuple]],
default: Optional[List[int]] = None,
) -> List[int]:
"""Normalize parameter to list: handle None, int, tuple, list."""
if value is None:
if default is None:
raise ValueError("Value cannot be None without a default")
return default
if isinstance(value, int):
return [value]
return list(value)


class DecomposeMaxPool1dPass(ArmPass):
"""Decomposes max_pool1d into max_pool2d via unsqueeze_copy/squeeze_copy
operations.

This pass runs in transform_for_annotation (TFA) pipeline before quantization,
ensuring proper quantization annotation for the decomposed ops.

Transformation:
max_pool1d(x, kernel, stride, padding, dilation, ceil_mode)
→ unsqueeze_copy(x, dim=2) # (N,C,L) → (N,C,1,L)
→ max_pool2d(..., [1,k], [1,s], [0,p], [1,d], ceil_mode)
→ squeeze_copy(..., dims=[2]) # (N,C,1,L') → (N,C,L')
"""
Comment on lines +29 to +41

Copilot AI Mar 3, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are existing max_pool1d backend tests currently marked xfail (e.g., backends/arm/test/ops/test_max_pool1d.py). Since this PR adds explicit MaxPool1D support, it should also update/enable those tests (and adjust any expected edge op name if the lowering changes from the previous max_pool2d_with_indices pattern).

Copilot uses AI. Check for mistakes.

_passes_required_after: Set[Type[ExportPass]] = set()

def call_operator(self, op, args, kwargs, meta):
if op != torch.ops.aten.max_pool1d.default or not self.allowed_to_transform(
meta
):
return super().call_operator(op, args, kwargs, meta)
Comment on lines +45 to +49

Copilot AI Mar 3, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DecomposeMaxPool1dPass currently only matches torch.ops.aten.max_pool1d.default. In the backend lowering pipeline (transform_to_backend_pipeline), graphs are typically in the EXIR edge dialect, so this pass may never fire for non-quantized compilation unless it also handles the edge-dialect overload (and emits the corresponding edge ops). Consider supporting both aten and edge overloads (similar to other passes that accept both).

Copilot uses AI. Check for mistakes.

# Extract and normalize arguments
x = args[0]
kernel_size = _normalize_to_list(args[1])
stride = _normalize_to_list(
args[2] if len(args) > 2 else None,
default=kernel_size, # stride defaults to kernel_size
)
padding = _normalize_to_list(args[3] if len(args) > 3 else 0)
dilation = _normalize_to_list(args[4] if len(args) > 4 else 1)
ceil_mode = args[5] if len(args) > 5 else False
Comment on lines +54 to +60

Copilot AI Feb 19, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default handling for stride (diff line 56) is incomplete. When stride is omitted or is explicitly None, _normalize_to_list falls back to kernel_size, which is correct. However, if args[2] is an empty list or tuple, which is how the ATen schema encodes an unspecified stride (stride=[]), it is passed through unchanged, so max_pool2d receives [1] + [] = [1] as its stride instead of [1, k].

Consider adding validation that kernel_size, stride, padding, and dilation after normalization all have exactly one element (for 1D operations), to catch any potential issues early.

Copilot uses AI. Check for mistakes.

# Step 1: Unsqueeze input from 3D to 4D at dim=2
# (N, C, L) → (N, C, 1, L)
unsqueeze_meta = meta.copy()
unsqueeze_meta.data["input_qparams"] = {}
unsqueeze_meta.data["output_qparams"] = {}
x_4d = super().call_operator(
torch.ops.aten.unsqueeze_copy.default,
(x, 2),
{},
unsqueeze_meta,
updated=True,
)
Comment on lines +67 to +73

Copilot AI Mar 3, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When inserting the unsqueeze_copy node, the pass reuses the original node metadata unchanged. In the main Arm pipeline this pass can run post Q/DQ folding, so carrying over input_qparams/output_qparams from the original max_pool1d node can incorrectly mark the new view-like node as already-quantized. Other passes that insert view ops clear qparams on these intermediates (e.g., Conv1dUnsqueezePass). Consider copying meta and clearing input/output qparams for the inserted unsqueeze node.

Copilot uses AI. Check for mistakes.

# Step 2: Call max_pool2d with 2D parameters
# kernel: [k] → [1, k], stride: [s] → [1, s]
# padding: [p] → [0, p], dilation: [d] → [1, d]
pooled = super().call_operator(
torch.ops.aten.max_pool2d.default,
(
x_4d,
[1] + kernel_size,
[1] + stride,
[0] + padding,
[1] + dilation,
ceil_mode,
),
{},
meta,
updated=True,
)
Comment on lines +45 to +91

# Step 3: Squeeze output back to 3D at dims=[2]
# (N, C, 1, L') → (N, C, L')
squeeze_meta = meta.copy()
squeeze_meta.data["input_qparams"] = {}
squeeze_meta.data["output_qparams"] = {}
output = super().call_operator(
torch.ops.aten.squeeze_copy.dims,
(pooled, [2]),
{},
squeeze_meta,
updated=True,
)
Comment on lines +98 to +104

Copilot AI Mar 3, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the inserted unsqueeze_copy: the inserted squeeze_copy node currently reuses the original node metadata unchanged. If this pass runs after quantization/QDQ folding, preserving qparams here can lead to incorrect quant metadata on a view-like op. Consider using a copied meta with cleared input_qparams/output_qparams for this intermediate reshape as well.

Copilot uses AI. Check for mistakes.

return output
22 changes: 7 additions & 15 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide quantization annotation logic for Arm backends.

This module computes per-node quantization properties and applies input/output
annotations to FX graphs using TorchAO qspecs.

This module computes per-node quantization properties and applies
input/output annotations to FX graphs using TorchAO qspecs.
"""

import functools
Expand Down Expand Up @@ -72,7 +71,6 @@
indexed by argument positions.
quant_output (Optional[_QuantProperty]): Quantization spec for the
node's output when applicable.

"""

def __init__(self):
Expand All @@ -93,7 +91,6 @@

Returns:
list: ``x`` if already a list; otherwise ``[x]``.

"""
if isinstance(x, (list, tuple)):
return x
Expand Down Expand Up @@ -206,7 +203,6 @@

Returns:
bool: `True` if the node can be quantized, otherwise `False`.

"""
# Check output
if quant_properties.quant_output is not None:
Expand Down Expand Up @@ -266,7 +262,6 @@

Returns:
Any: Resolved attribute on the module.

"""
targets = target_str.split(".")
for target in targets[:-1]:
Expand All @@ -279,7 +274,6 @@

Large scalars are skipped because ``torch.histc`` supports values only up
to a certain upper bound.

"""
HISTC_UPPER_BOUND = 3.4028235e15
if node.op == "get_attr" and isinstance(node.target, str):
Expand All @@ -297,7 +291,8 @@


def _is_non_float_tensor(node: Node) -> bool:
"""Check if the output of a node has a data type other than `torch.float32`.
"""Check if the output of a node has a data type other than
`torch.float32`.

If the output is not `torch.float32`, quantization cannot be performed, as
observers only work with floating-point tensors.
Expand All @@ -314,7 +309,6 @@
`torch.float32` as its data type.
- If node.meta["val"] is missing or is not an instance of `FakeTensor`,
the function returns True.

"""
if "val" in node.meta and isinstance(node.meta["val"], Sequence):
return any(
Expand Down Expand Up @@ -342,7 +336,6 @@
Raises:
RuntimeError: If the node is already annotated.
TypeError: If an input argument is not a ``Node`` instance.

"""
if is_annotated(node):
raise RuntimeError(
Expand Down Expand Up @@ -379,7 +372,6 @@
RuntimeError: If the node is already annotated.
ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
``index`` is not zero.

"""
if is_annotated(node):
raise RuntimeError(
Expand Down Expand Up @@ -408,7 +400,6 @@
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
to pass the filter. Each pattern element is an iterable of disjunctive
node targets.

"""
if len(pattern) < 1:
raise ValueError("No pattern provided")
Expand Down Expand Up @@ -517,6 +508,9 @@
torch.ops.aten.squeeze_copy.default,
torch.ops.aten.squeeze_copy.dim,
torch.ops.aten.squeeze_.dim,
# DecomposeMaxPool1dPass emits squeeze_copy.dims as a view-like intermediate;
# include here so it receives SharedQuantizationSpec from its input.
torch.ops.aten.squeeze_copy.dims,
Comment thread
Ninja91 marked this conversation as resolved.
torch.ops.aten.squeeze.dim,
torch.ops.aten.squeeze.dims,
torch.ops.aten.unbind.int,
Expand Down Expand Up @@ -612,7 +606,6 @@
Returns:
_OpQuantProperties | None: Properties to apply, or ``None`` if the
node is unsupported or not suitable for quantization.

"""
if node.target == torch.ops.aten.conv_transpose2d.input:
weight_qspec = _adjust_weight_qspec_for_conv_transpose(
Expand Down Expand Up @@ -950,7 +943,6 @@

Returns:
Optional[List[List[Node]]]: Reserved for future use; currently None.

"""
for node in gm.graph.nodes:
if node.op != "call_function":
Expand Down
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_max_pool1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def test_max_pool2d_tosa_FP_decomposed(test_data: Callable):


@common.parametrize("test_data", test_data_suite_all)
@pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have expected U55 and U85 tests to pass now as all the independent operators of the decomposition are supported.

def test_max_pool2d_tosa_INT_decomposed(test_data: Callable):
"""Test max_pool1d with TOSA INT pipeline (quantized)."""
test_data, model_params = test_data()
Expand Down
Loading