Skip to content

Commit 2495c4a

Browse files
authored
Merge branch 'main' into toupstream/arg_max
2 parents 1b97778 + fe9bc95 commit 2495c4a

48 files changed

Lines changed: 1555 additions & 189 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
name: Test Cortex-M ops
2+
3+
permissions:
4+
id-token: write
5+
contents: read
6+
7+
on:
8+
workflow_call:
9+
inputs:
10+
targets:
11+
description: 'JSON array of cortex-m target CPUs to run the op tests against, e.g. ["cortex-m7", "cortex-m0plus"]'
12+
required: true
13+
type: string
14+
timeout:
15+
description: 'Per-matrix-entry timeout in minutes'
16+
required: false
17+
type: number
18+
default: 120
19+
20+
jobs:
21+
run:
22+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
23+
strategy:
24+
matrix:
25+
target: ${{ fromJSON(inputs.targets) }}
26+
fail-fast: false
27+
with:
28+
job-name: cortex-m-ops-${{ matrix.target }}
29+
runner: linux.2xlarge.memory
30+
docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk
31+
submodules: 'recursive'
32+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
33+
timeout: ${{ inputs.timeout }}
34+
script: |
35+
# The generic Linux job chooses to use base env, not the one setup by the image
36+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
37+
conda activate "${CONDA_ENV}"
38+
39+
source .ci/scripts/utils.sh
40+
install_executorch "--use-pt-pinned-commit"
41+
42+
# Install arm dependencies
43+
.ci/scripts/setup-arm-baremetal-tools.sh
44+
source examples/arm/arm-scratch/setup_path.sh
45+
46+
# Build the runner for this target (written to a target-suffixed dir
47+
# that the op tests resolve from via --cortex-m-target below).
48+
backends/cortex_m/test/build_test_runner.sh --target=${{ matrix.target }}
49+
50+
# Run the op suite against this target: dialect tests check the lowered
51+
# op set, implementation tests check FVP numerics. Both are parametrized
52+
# over --cortex-m-target, so a future target-dependent lowering change is
53+
# caught here. (cortex-m55 runs on pull via the full-suite job.)
54+
pytest --config-file=backends/arm/test/pytest.ini \
55+
backends/cortex_m/test/ops \
56+
--cortex-m-target=${{ matrix.target }}

.github/workflows/trunk.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,3 +1076,12 @@ jobs:
10761076
with:
10771077
models: '["mv2", "mv3"]'
10781078
targets: '["cortex-m55", "cortex-m7", "cortex-m0plus"]'
1079+
1080+
test-cortex-m-ops:
1081+
name: test-cortex-m-ops
1082+
permissions:
1083+
id-token: write
1084+
contents: read
1085+
uses: ./.github/workflows/_test_cortex_m_ops.yml
1086+
with:
1087+
targets: '["cortex-m7", "cortex-m0plus"]'

.lintrunner.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ exclude_patterns = [
195195
# Kernel areas to onboard separately.
196196
'kernels/optimized/**',
197197
'kernels/portable/**',
198-
'kernels/quantized/**',
199198
'kernels/test/**',
200199

201200
# Runtime areas to onboard incrementally.
@@ -229,6 +228,12 @@ command = [
229228
'--extra-arg=--suppress=unknownMacro:*kernels/prim_ops/*',
230229
'--extra-arg=--suppress=syntaxError:*kernels/prim_ops/*',
231230
'--extra-arg=--suppress=unusedFunction:*kernels/prim_ops/*',
231+
# Quantized kernels have NEON-gated code and registration helpers that
232+
# cppcheck cannot see in every configuration.
233+
'--extra-arg=--suppress=unreadVariable:*kernels/quantized/*',
234+
'--extra-arg=--suppress=unusedFunction:*kernels/quantized/*',
235+
'--extra-arg=--suppress=constParameterReference:*kernels/quantized/*',
236+
'--extra-arg=--suppress=suspiciousFloatingPointCast:*kernels/quantized/*',
232237
'--',
233238
'@{{PATHSFILE}}'
234239
]

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
4444
from .decompose_div_pass import DecomposeDivPass # noqa
4545
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
46+
from .decompose_dynamic_adaptive_avg_pool2d_pass import ( # noqa
47+
DecomposeDynamicAdaptiveAvgPool2dPass,
48+
)
4649
from .decompose_dynamic_full_pass import DecomposeDynamicFullPass # noqa
4750
from .decompose_einsum_pass import DecomposeEinsumPass # noqa
4851
from .decompose_elu_pass import ConvertEluFamilyToEluPass, DecomposeEluPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import logging
99
from collections import defaultdict
10-
from collections.abc import Sequence
10+
from collections.abc import Callable, Sequence
1111
from dataclasses import dataclass, field
12+
from typing import Any, cast
1213

1314
from executorch.backends.arm._passes import (
1415
AccumulateIndexPutPass,
@@ -49,6 +50,7 @@
4950
DecomposeCumsumPass,
5051
DecomposeDivPass,
5152
DecomposeDivTensorModePass,
53+
DecomposeDynamicAdaptiveAvgPool2dPass,
5254
DecomposeDynamicFullPass,
5355
DecomposeEinsumPass,
5456
DecomposeEluPass,
@@ -166,12 +168,17 @@
166168
)
167169

168170
from executorch.exir import ExportedProgram
169-
from executorch.exir.pass_base import ExportPass
170-
from executorch.exir.pass_manager import PassManager
171+
from executorch.exir._program_utils import _get_updated_graph_signature
172+
from executorch.exir.pass_base import (
173+
ExportedProgramPassBase,
174+
ExportedProgramPassResult,
175+
ExportPass,
176+
)
177+
from executorch.exir.pass_manager import ExportedProgramPassManager
171178
from torch._export.utils import _get_shape_env_from_gm
172179
from torch.fx import GraphModule
173180
from torch.fx.passes.infra.pass_base import PassResult
174-
from torch.nn.modules import Module
181+
from torch.fx.passes.infra.pass_manager import PassManager as GraphModulePassManager
175182

176183
logger = logging.getLogger(__name__)
177184

@@ -187,6 +194,50 @@ class PassInsertions:
187194
_registered_pass_insertions: dict[type, PassInsertions] = {}
188195

189196

197+
def _graph_pass_name(graph_pass: Callable[[GraphModule], PassResult | None]) -> str:
198+
if isinstance(graph_pass, ExportPass):
199+
return ArmPass.get_name(graph_pass)
200+
if hasattr(graph_pass, "__name__"):
201+
return graph_pass.__name__
202+
return type(graph_pass).__name__
203+
204+
205+
class _ExportedProgramGraphPassAdapter(ExportedProgramPassBase):
206+
def __init__(self, graph_pass: Callable[[GraphModule], PassResult | None]) -> None:
207+
self.graph_pass = graph_pass
208+
209+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
210+
graph_pass = cast(Any, self.graph_pass)
211+
pass_exported_program = getattr(graph_pass, "exported_program", None)
212+
if pass_exported_program is not None:
213+
# ExportedProgramPassManager works on a shallow copy; Arm graph
214+
# passes that store an ExportedProgram must update that copy.
215+
graph_pass.exported_program = exported_program
216+
217+
try:
218+
result = self.graph_pass(exported_program.graph_module)
219+
finally:
220+
if pass_exported_program is not None:
221+
graph_pass.exported_program = pass_exported_program
222+
223+
if result is None:
224+
raise TypeError(
225+
f"The result of pass {_graph_pass_name(self.graph_pass)} should be type PassResult."
226+
)
227+
228+
if result.modified:
229+
result.graph_module.recompile()
230+
exported_program._graph_module = result.graph_module
231+
exported_program._graph_signature = _get_updated_graph_signature(
232+
exported_program.graph_signature,
233+
result.graph_module,
234+
)
235+
# Arm graph passes do not change symbolic shape constraints, and
236+
# metadata-only fake modes may differ after propagation.
237+
238+
return ExportedProgramPassResult(exported_program, result.modified)
239+
240+
190241
def register_pass_insertions_before(
191242
target_pass_type: type, passes: list[ExportPass]
192243
) -> None:
@@ -210,7 +261,7 @@ def clear_registered_pass_insertions() -> None:
210261
_registered_pass_insertions.clear()
211262

212263

213-
class ArmPassManager(PassManager):
264+
class ArmPassManager(ExportedProgramPassManager):
214265
def __init__(self, compile_spec: ArmCompileSpec) -> None:
215266
self.compile_spec = compile_spec
216267
self.tosa_spec = compile_spec.tosa_spec
@@ -373,8 +424,39 @@ def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext:
373424
shape_env = _get_shape_env_from_gm(graph_module)
374425
return TosaLoweringContext(self.tosa_spec, shape_env)
375426

376-
def _transform(self, graph_module: GraphModule):
377-
return self(graph_module).graph_module
427+
def _transform_graph_module(self, graph_module: GraphModule):
428+
# TFA and control-flow submodule paths operate on bare GraphModules
429+
# without a standalone ExportedProgram to keep in sync.
430+
return GraphModulePassManager(self.passes)(graph_module).graph_module
431+
432+
def __call__( # type: ignore[override]
433+
self,
434+
module: ExportedProgram | GraphModule,
435+
override_verifiers: Any | None = None,
436+
) -> ExportedProgramPassResult | PassResult:
437+
if isinstance(module, GraphModule):
438+
if override_verifiers is not None:
439+
raise ValueError("override_verifiers is only valid for ExportedProgram")
440+
return GraphModulePassManager(self.passes)(module)
441+
return super().__call__(module, override_verifiers)
442+
443+
def _transform(
444+
self,
445+
exported_program: ExportedProgram,
446+
graph_module: GraphModule,
447+
) -> GraphModule:
448+
if graph_module is exported_program.graph_module:
449+
passes: list[
450+
ExportedProgramPassBase | Callable[[GraphModule], PassResult | None]
451+
] = [_ExportedProgramGraphPassAdapter(p) for p in self.passes]
452+
transformed_program = ExportedProgramPassManager(passes)(
453+
exported_program
454+
).exported_program
455+
exported_program._graph_module = transformed_program.graph_module
456+
exported_program._graph_signature = transformed_program.graph_signature
457+
exported_program._range_constraints = transformed_program.range_constraints
458+
return exported_program.graph_module
459+
return self._transform_graph_module(graph_module)
378460

379461
def add_pass(self, pipeline_pass):
380462
if type(pipeline_pass) in self._skip_pass_types:
@@ -463,6 +545,7 @@ def _tosa_pipeline(
463545
AccumulateIndexPutPass(),
464546
DecomposeIndexTensorToGatherPass(),
465547
DecomposeAdaptiveAvgPool2dPass(),
548+
DecomposeDynamicAdaptiveAvgPool2dPass(),
466549
DecomposeAvgPool2dPass(),
467550
Conv1dUnsqueezePass(),
468551
]
@@ -556,7 +639,7 @@ def _tosa_pipeline(
556639
self._apply_pass_insertions()
557640

558641
self.validate_constraints_mandatory()
559-
return self._transform(graph_module)
642+
return self._transform(exported_program, graph_module)
560643

561644
def transform_to_backend_pipeline(
562645
self, exported_program: ExportedProgram, graph_module: GraphModule
@@ -661,21 +744,4 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
661744
]
662745
)
663746

664-
return self._transform(graph_module)
665-
666-
def __call__(self, module: Module) -> PassResult:
667-
try:
668-
return super().__call__(module)
669-
except Exception as e:
670-
first_exception = e.__cause__ or e.__context__ or e
671-
import re
672-
673-
message = e.args[0]
674-
m = re.search(r"An error occurred when running the '([^']+)' pass", message)
675-
if m:
676-
pass_name = m.group(1)
677-
first_exception.args = (
678-
f"{pass_name}: {first_exception.args[0]}",
679-
*first_exception.args[1:],
680-
)
681-
raise first_exception
747+
return self._transform_graph_module(graph_module)

0 commit comments

Comments
 (0)