77
88import logging
99from collections import defaultdict
10- from collections .abc import Sequence
10+ from collections .abc import Callable , Sequence
1111from dataclasses import dataclass , field
12+ from typing import Any , cast
1213
1314from executorch .backends .arm ._passes import (
1415 AccumulateIndexPutPass ,
4950 DecomposeCumsumPass ,
5051 DecomposeDivPass ,
5152 DecomposeDivTensorModePass ,
53+ DecomposeDynamicAdaptiveAvgPool2dPass ,
5254 DecomposeDynamicFullPass ,
5355 DecomposeEinsumPass ,
5456 DecomposeEluPass ,
166168)
167169
168170from executorch .exir import ExportedProgram
169- from executorch .exir .pass_base import ExportPass
170- from executorch .exir .pass_manager import PassManager
171+ from executorch .exir ._program_utils import _get_updated_graph_signature
172+ from executorch .exir .pass_base import (
173+ ExportedProgramPassBase ,
174+ ExportedProgramPassResult ,
175+ ExportPass ,
176+ )
177+ from executorch .exir .pass_manager import ExportedProgramPassManager
171178from torch ._export .utils import _get_shape_env_from_gm
172179from torch .fx import GraphModule
173180from torch .fx .passes .infra .pass_base import PassResult
174- from torch .nn . modules import Module
181+ from torch .fx . passes . infra . pass_manager import PassManager as GraphModulePassManager
175182
176183logger = logging .getLogger (__name__ )
177184
@@ -187,6 +194,50 @@ class PassInsertions:
187194_registered_pass_insertions : dict [type , PassInsertions ] = {}
188195
189196
197+ def _graph_pass_name (graph_pass : Callable [[GraphModule ], PassResult | None ]) -> str :
198+ if isinstance (graph_pass , ExportPass ):
199+ return ArmPass .get_name (graph_pass )
200+ if hasattr (graph_pass , "__name__" ):
201+ return graph_pass .__name__
202+ return type (graph_pass ).__name__
203+
204+
205+ class _ExportedProgramGraphPassAdapter (ExportedProgramPassBase ):
206+ def __init__ (self , graph_pass : Callable [[GraphModule ], PassResult | None ]) -> None :
207+ self .graph_pass = graph_pass
208+
209+ def call (self , exported_program : ExportedProgram ) -> ExportedProgramPassResult :
210+ graph_pass = cast (Any , self .graph_pass )
211+ pass_exported_program = getattr (graph_pass , "exported_program" , None )
212+ if pass_exported_program is not None :
213+ # ExportedProgramPassManager works on a shallow copy; Arm graph
214+ # passes that store an ExportedProgram must update that copy.
215+ graph_pass .exported_program = exported_program
216+
217+ try :
218+ result = self .graph_pass (exported_program .graph_module )
219+ finally :
220+ if pass_exported_program is not None :
221+ graph_pass .exported_program = pass_exported_program
222+
223+ if result is None :
224+ raise TypeError (
225+ f"The result of pass { _graph_pass_name (self .graph_pass )} should be type PassResult."
226+ )
227+
228+ if result .modified :
229+ result .graph_module .recompile ()
230+ exported_program ._graph_module = result .graph_module
231+ exported_program ._graph_signature = _get_updated_graph_signature (
232+ exported_program .graph_signature ,
233+ result .graph_module ,
234+ )
235+ # Arm graph passes do not change symbolic shape constraints, and
236+ # metadata-only fake modes may differ after propagation.
237+
238+ return ExportedProgramPassResult (exported_program , result .modified )
239+
240+
190241def register_pass_insertions_before (
191242 target_pass_type : type , passes : list [ExportPass ]
192243) -> None :
@@ -210,7 +261,7 @@ def clear_registered_pass_insertions() -> None:
210261 _registered_pass_insertions .clear ()
211262
212263
213- class ArmPassManager (PassManager ):
264+ class ArmPassManager (ExportedProgramPassManager ):
214265 def __init__ (self , compile_spec : ArmCompileSpec ) -> None :
215266 self .compile_spec = compile_spec
216267 self .tosa_spec = compile_spec .tosa_spec
@@ -373,8 +424,39 @@ def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext:
373424 shape_env = _get_shape_env_from_gm (graph_module )
374425 return TosaLoweringContext (self .tosa_spec , shape_env )
375426
376- def _transform (self , graph_module : GraphModule ):
377- return self (graph_module ).graph_module
427+ def _transform_graph_module (self , graph_module : GraphModule ):
428+ # TFA and control-flow submodule paths operate on bare GraphModules
429+ # without a standalone ExportedProgram to keep in sync.
430+ return GraphModulePassManager (self .passes )(graph_module ).graph_module
431+
432+ def __call__ ( # type: ignore[override]
433+ self ,
434+ module : ExportedProgram | GraphModule ,
435+ override_verifiers : Any | None = None ,
436+ ) -> ExportedProgramPassResult | PassResult :
437+ if isinstance (module , GraphModule ):
438+ if override_verifiers is not None :
439+ raise ValueError ("override_verifiers is only valid for ExportedProgram" )
440+ return GraphModulePassManager (self .passes )(module )
441+ return super ().__call__ (module , override_verifiers )
442+
443+ def _transform (
444+ self ,
445+ exported_program : ExportedProgram ,
446+ graph_module : GraphModule ,
447+ ) -> GraphModule :
448+ if graph_module is exported_program .graph_module :
449+ passes : list [
450+ ExportedProgramPassBase | Callable [[GraphModule ], PassResult | None ]
451+ ] = [_ExportedProgramGraphPassAdapter (p ) for p in self .passes ]
452+ transformed_program = ExportedProgramPassManager (passes )(
453+ exported_program
454+ ).exported_program
455+ exported_program ._graph_module = transformed_program .graph_module
456+ exported_program ._graph_signature = transformed_program .graph_signature
457+ exported_program ._range_constraints = transformed_program .range_constraints
458+ return exported_program .graph_module
459+ return self ._transform_graph_module (graph_module )
378460
379461 def add_pass (self , pipeline_pass ):
380462 if type (pipeline_pass ) in self ._skip_pass_types :
@@ -463,6 +545,7 @@ def _tosa_pipeline(
463545 AccumulateIndexPutPass (),
464546 DecomposeIndexTensorToGatherPass (),
465547 DecomposeAdaptiveAvgPool2dPass (),
548+ DecomposeDynamicAdaptiveAvgPool2dPass (),
466549 DecomposeAvgPool2dPass (),
467550 Conv1dUnsqueezePass (),
468551 ]
@@ -556,7 +639,7 @@ def _tosa_pipeline(
556639 self ._apply_pass_insertions ()
557640
558641 self .validate_constraints_mandatory ()
559- return self ._transform (graph_module )
642+ return self ._transform (exported_program , graph_module )
560643
561644 def transform_to_backend_pipeline (
562645 self , exported_program : ExportedProgram , graph_module : GraphModule
@@ -661,21 +744,4 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
661744 ]
662745 )
663746
664- return self ._transform (graph_module )
665-
666- def __call__ (self , module : Module ) -> PassResult :
667- try :
668- return super ().__call__ (module )
669- except Exception as e :
670- first_exception = e .__cause__ or e .__context__ or e
671- import re
672-
673- message = e .args [0 ]
674- m = re .search (r"An error occurred when running the '([^']+)' pass" , message )
675- if m :
676- pass_name = m .group (1 )
677- first_exception .args = (
678- f"{ pass_name } : { first_exception .args [0 ]} " ,
679- * first_exception .args [1 :],
680- )
681- raise first_exception
747+ return self ._transform_graph_module (graph_module )
0 commit comments