2020class FuseCascadedTransposeOrPermuteOps (RemoveOrReplacePassInterface ):
2121 """
2222 Fuse a chain of transpose and permute ops into a single permute or a no-op.
23- Handles branches and chains permutes.
23+ Handles branches and chains of permutes, including permute-view-permute
24+ patterns where a squeeze/unsqueeze view sits between two permutes.
2425 """
2526
2627 transpose_or_permute_target = {
2728 exir_ops .edge .aten .transpose_copy .int ,
2829 exir_ops .edge .aten .permute_copy .default ,
2930 }
3031
32+ _VIEW_OPS = {
33+ exir_ops .edge .aten .view_copy .default ,
34+ exir_ops .edge .aten .view .default ,
35+ }
36+
@property
def targets(self) -> list[EdgeOpOverload]:
    """Return the op overloads this pass visits.

    The result is a fresh list holding every member of
    ``transpose_or_permute_target``.
    """
    return [*self.transpose_or_permute_target]
3440
def maybe_remove_or_replace(self, node: Node) -> bool:
    """Try to fuse ``node`` with the op that produces its input.

    Dispatches on the target of the producer of ``node``'s "input" argument:

    * a permute/transpose producer is handled by ``_fuse_direct``;
    * a view producer (one of ``_VIEW_OPS``) is handled by
      ``_fuse_across_view``.

    Returns:
        Whatever the selected helper returns, or False when the producer is
        neither a permute/transpose nor a view op.
    """
    producer = get_arg(node, "input", Node)
    producer_target = producer.target

    # Case 1: permute/transpose -> permute/transpose.
    if producer_target in self.transpose_or_permute_target:
        return self._fuse_direct(node, producer)

    # Case 2: permute -> view (squeeze/unsqueeze) -> permute.
    if producer_target in self._VIEW_OPS:
        return self._fuse_across_view(node, producer)

    return False
53+
54+ def _fuse_direct (self , node : Node , parent_node : Node ) -> bool :
55+ """Fuse two adjacent permute/transpose ops."""
56+ input_of_parent = get_arg (parent_node , "input" , Node )
4557 dims = list (range (node .meta ["val" ].ndim ))
4658
4759 if parent_node .target == exir_ops .edge .aten .transpose_copy .int :
@@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
5466 else :
5567 dims = get_permuted_dims (node , dims )
5668
57- # If combined effect is identity replace the node with input.
5869 if dims == sorted (dims ):
5970 node .replace_all_uses_with (input_of_parent )
6071 else :
@@ -67,3 +78,91 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
6778 node .replace_all_uses_with (new_permute )
6879
6980 return True
81+
def _fuse_across_view(self, node: Node, view_node: Node) -> bool:
    """Fuse permute -> view(squeeze/unsqueeze) -> permute into one view_copy.

    For every position of the final output, this tracks which dimension of
    the original (pre-permute) input ends up there. A dimension inserted by
    an unsqueeze is marked with -1. If the surviving original dimensions
    appear in increasing order, the two permutes cancel each other and the
    whole chain is replaced by a single view_copy of the original input to
    ``node``'s output shape.

    Args:
        node: The second permute/transpose, i.e. the op being visited.
        view_node: The producer of ``node``'s input.

    Returns:
        True if all uses of ``node`` were redirected to a new view_copy;
        False if the pattern does not match or the permutes do not cancel.

    Raises:
        ValueError: If ``node.target`` is neither ``transpose_copy.int`` nor
            ``permute_copy.default``.
    """
    # The view must feed only this permute.
    if len(view_node.users) != 1:
        return False

    # The view's producer must itself be a permute/transpose.
    first_permute = get_arg(view_node, "input", Node)
    if first_permute.target not in self.transpose_or_permute_target:
        return False

    # Only views that change the rank by exactly one can be a single
    # squeeze or unsqueeze.
    view_in_shape = first_permute.meta["val"].shape
    view_out_shape = view_node.meta["val"].shape
    if abs(len(view_in_shape) - len(view_out_shape)) != 1:
        return False

    input_of_first_permute = get_arg(first_permute, "input", Node)

    # dims[i] is the original-input dimension found at position i.
    dims = list(range(input_of_first_permute.meta["val"].ndim))
    if first_permute.target == exir_ops.edge.aten.transpose_copy.int:
        dims = get_transposed_dims(first_permute, dims)
    else:
        dims = get_permuted_dims(first_permute, dims)
    # Work on a private list so the in-place edits below stay local.
    dims = list(dims)

    if len(view_out_shape) == len(view_in_shape) + 1:
        # Unsqueeze: a size-1 dim appears at `index`. The original dimension
        # ids are left untouched; only the relative order of the real dims
        # matters for the identity check further down.
        index = self._find_extra_one(view_out_shape, view_in_shape)
        if index == -1:
            return False
        dims.insert(index, -1)
    else:
        # Squeeze (the rank check above leaves no other possibility): the
        # size-1 dim at `index` disappears.
        index = self._find_extra_one(view_in_shape, view_out_shape)
        if index == -1:
            return False
        del dims[index]

    # Apply the second permute/transpose on top of the tracked positions.
    if node.target == exir_ops.edge.aten.transpose_copy.int:
        node_dims = get_transposed_dims(node, list(range(len(dims))))
        dims = [dims[d] for d in node_dims]
    elif node.target == exir_ops.edge.aten.permute_copy.default:
        dims = [dims[d] for d in get_arg(node, "dims")]
    else:
        raise ValueError(f"Unexpected target: {node.target}")

    # Ignore inserted dims; the permutes cancel only if the original dims
    # are still in increasing order.
    real_dims = [d for d in dims if d != -1]
    if real_dims != sorted(real_dims):
        return False

    # The output rank always differs from the original input's rank by one
    # (the view changed it, the permutes cannot), so the output shape can
    # never equal the input shape: a reshape is always required.
    with node.graph.inserting_before(node):
        new_view = node.graph.call_function(
            exir_ops.edge.aten.view_copy.default,
            args=(input_of_first_permute, list(node.meta["val"].shape)),
        )
    new_view.meta = node.meta
    node.replace_all_uses_with(new_view)
    return True
158+
159+ @staticmethod
160+ def _find_extra_one (longer : list [int ], shorter : list [int ]) -> int :
161+ if len (longer ) != len (shorter ) + 1 :
162+ return - 1
163+ for i in range (len (shorter )):
164+ if longer [i ] != shorter [i ]:
165+ if longer [i ] == 1 and shorter [i :] == longer [i + 1 :]:
166+ return i
167+ return - 1
168+ return len (shorter ) if longer [- 1 ] == 1 else - 1
0 commit comments