-
Notifications
You must be signed in to change notification settings - Fork 1k
Add MaxPool1D decomposition pass support (#17022) #17022
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import List, Optional, Set, Type, Union | ||
|
|
||
| import torch | ||
| from executorch.backends.arm._passes.arm_pass import ArmPass | ||
| from executorch.exir.pass_base import ExportPass | ||
|
|
||
|
|
||
| def _normalize_to_list( | ||
| value: Optional[Union[int, List[int], tuple]], | ||
| default: Optional[List[int]] = None, | ||
| ) -> List[int]: | ||
| """Normalize parameter to list: handle None, int, tuple, list.""" | ||
| if value is None: | ||
| if default is None: | ||
| raise ValueError("Value cannot be None without a default") | ||
| return default | ||
| if isinstance(value, int): | ||
| return [value] | ||
| return list(value) | ||
|
|
||
|
|
||
| class DecomposeMaxPool1dPass(ArmPass): | ||
| """Decomposes max_pool1d into max_pool2d via unsqueeze_copy/squeeze_copy | ||
| operations. | ||
|
|
||
| This pass runs in transform_for_annotation (TFA) pipeline before quantization, | ||
| ensuring proper quantization annotation for the decomposed ops. | ||
|
|
||
| Transformation: | ||
| max_pool1d(x, kernel, stride, padding, dilation, ceil_mode) | ||
| → unsqueeze_copy(x, dim=2) # (N,C,L) → (N,C,1,L) | ||
| → max_pool2d(..., [1,k], [1,s], [0,p], [1,d], ceil_mode) | ||
| → squeeze_copy(..., dims=[2]) # (N,C,1,L') → (N,C,L') | ||
| """ | ||
|
Comment on lines
+29
to
+41
|
||
|
|
||
| _passes_required_after: Set[Type[ExportPass]] = set() | ||
|
|
||
| def call_operator(self, op, args, kwargs, meta): | ||
| if op != torch.ops.aten.max_pool1d.default or not self.allowed_to_transform( | ||
| meta | ||
| ): | ||
| return super().call_operator(op, args, kwargs, meta) | ||
|
Comment on lines
+45
to
+49
|
||
|
|
||
| # Extract and normalize arguments | ||
| x = args[0] | ||
| kernel_size = _normalize_to_list(args[1]) | ||
| stride = _normalize_to_list( | ||
| args[2] if len(args) > 2 else None, | ||
| default=kernel_size, # stride defaults to kernel_size | ||
| ) | ||
| padding = _normalize_to_list(args[3] if len(args) > 3 else 0) | ||
| dilation = _normalize_to_list(args[4] if len(args) > 4 else 1) | ||
| ceil_mode = args[5] if len(args) > 5 else False | ||
|
Comment on lines
+54
to
+60
|
||
|
|
||
| # Step 1: Unsqueeze input from 3D to 4D at dim=2 | ||
| # (N, C, L) → (N, C, 1, L) | ||
| unsqueeze_meta = meta.copy() | ||
| unsqueeze_meta.data["input_qparams"] = {} | ||
| unsqueeze_meta.data["output_qparams"] = {} | ||
| x_4d = super().call_operator( | ||
| torch.ops.aten.unsqueeze_copy.default, | ||
| (x, 2), | ||
| {}, | ||
| unsqueeze_meta, | ||
| updated=True, | ||
| ) | ||
|
Comment on lines
+67
to
+73
|
||
|
|
||
| # Step 2: Call max_pool2d with 2D parameters | ||
| # kernel: [k] → [1, k], stride: [s] → [1, s] | ||
| # padding: [p] → [0, p], dilation: [d] → [1, d] | ||
| pooled = super().call_operator( | ||
| torch.ops.aten.max_pool2d.default, | ||
| ( | ||
| x_4d, | ||
| [1] + kernel_size, | ||
| [1] + stride, | ||
| [0] + padding, | ||
| [1] + dilation, | ||
| ceil_mode, | ||
| ), | ||
| {}, | ||
| meta, | ||
| updated=True, | ||
| ) | ||
|
Comment on lines
+45
to
+91
|
||
|
|
||
| # Step 3: Squeeze output back to 3D at dims=[2] | ||
| # (N, C, 1, L') → (N, C, L') | ||
| squeeze_meta = meta.copy() | ||
| squeeze_meta.data["input_qparams"] = {} | ||
| squeeze_meta.data["output_qparams"] = {} | ||
| output = super().call_operator( | ||
| torch.ops.aten.squeeze_copy.dims, | ||
| (pooled, [2]), | ||
| {}, | ||
| squeeze_meta, | ||
| updated=True, | ||
| ) | ||
|
Comment on lines
+98
to
+104
|
||
|
|
||
| return output | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,7 +88,6 @@ def test_max_pool2d_tosa_FP_decomposed(test_data: Callable): | |
|
|
||
|
|
||
| @common.parametrize("test_data", test_data_suite_all) | ||
| @pytest.mark.xfail(reason="MaxPool1D not yet supported", strict=False) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would have expected U55 and U85 tests to pass now as all the independent operators of the decomposition are supported. |
||
| def test_max_pool2d_tosa_INT_decomposed(test_data: Callable): | ||
| """Test max_pool1d with TOSA INT pipeline (quantized).""" | ||
| test_data, model_params = test_data() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with this one. We shouldn't see any max_pool1d:s at this stage, the pass won't match any exir-targets and will emit torch.ops rather than exir_ops.