2020class FuseCascadedTransposeOrPermuteOps (RemoveOrReplacePassInterface ):
2121 """
2222 Fuse a chain of transpose and permute ops into a single permute or a no-op.
23- Handles branches and chains permutes.
23+ Handles branches and chains of permutes, including permute-view-permute
24+ patterns where a squeeze/unsqueeze view sits between two permutes.
2425 """
2526
2627 transpose_or_permute_target = {
2728 exir_ops .edge .aten .transpose_copy .int ,
2829 exir_ops .edge .aten .permute_copy .default ,
2930 }
3031
32+ _VIEW_OPS = {
33+ exir_ops .edge .aten .view_copy .default ,
34+ exir_ops .edge .aten .view .default ,
35+ }
36+
@property
def targets(self) -> list[EdgeOpOverload]:
    """Return the ops held in ``transpose_or_permute_target`` as a list."""
    return [*self.transpose_or_permute_target]
3440
def maybe_remove_or_replace(self, node: Node) -> bool:
    """Try to fuse ``node`` with the op that produces its input.

    Two producer kinds are handled:
      * another permute/transpose -> delegated to ``_fuse_direct``;
      * a view op -> delegated to ``_fuse_across_view``.

    Returns whatever the selected helper returns, or False when the
    producer is neither kind.
    """
    producer = get_arg(node, "input", Node)
    producer_target = producer.target

    # permute/transpose directly feeding a permute/transpose
    if producer_target in self.transpose_or_permute_target:
        return self._fuse_direct(node, producer)

    # a view op sitting between this node and (possibly) another permute
    if producer_target in self._VIEW_OPS:
        return self._fuse_across_view(node, producer)

    return False
53+
54+ def _fuse_direct (self , node : Node , parent_node : Node ) -> bool :
55+ """Fuse two adjacent permute/transpose ops."""
56+ input_of_parent = get_arg (parent_node , "input" , Node )
4557 dims = list (range (node .meta ["val" ].ndim ))
4658
4759 if parent_node .target == exir_ops .edge .aten .transpose_copy .int :
@@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
5466 else :
5567 dims = get_permuted_dims (node , dims )
5668
57- # If combined effect is identity replace the node with input.
5869 if dims == sorted (dims ):
5970 node .replace_all_uses_with (input_of_parent )
6071 else :
@@ -67,3 +78,104 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
6778 node .replace_all_uses_with (new_permute )
6879
6980 return True
81+
82+ def _apply_view_to_dims (
83+ self , dims : list [int ], view_in_shape , view_out_shape
84+ ) -> list [int ] | None :
85+ """Apply a squeeze or unsqueeze view to dimension mapping.
86+
87+ Returns the updated dims, or None if the view cannot be mapped.
88+ """
89+ if len (view_out_shape ) == len (view_in_shape ) + 1 :
90+ # unsqueeze: insert a new dim
91+ index = self ._find_extra_one (view_out_shape , view_in_shape )
92+ if index == - 1 :
93+ return None
94+ dims = [x + 1 if x >= index else x for x in dims ]
95+ dims .insert (index , - 1 ) # -1 marks the inserted dim
96+ elif len (view_in_shape ) == len (view_out_shape ) + 1 :
97+ # squeeze: remove a dim
98+ index = self ._find_extra_one (view_in_shape , view_out_shape )
99+ if index == - 1 :
100+ return None
101+ dims = list (dims )
102+ del dims [index ]
103+ return dims
104+
def _fuse_across_view(self, node: Node, view_node: Node) -> bool:
    """Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy.

    Args:
        node: The trailing permute/transpose whose "input" is ``view_node``.
        view_node: The view op sitting between the two permutes.

    Returns:
        True if every use of ``node`` was redirected to a newly inserted
        ``view_copy`` on the input of the leading permute. False if the
        pattern does not match; the graph is then left untouched.

    Raises:
        ValueError: If ``node`` is neither ``transpose_copy.int`` nor
            ``permute_copy.default``.
    """
    # The view must feed only this permute.
    if len(view_node.users) != 1:
        return False

    # The view's producer must itself be a permute/transpose.
    first_permute = get_arg(view_node, "input", Node)
    if first_permute.target not in self.transpose_or_permute_target:
        return False

    # Only rank changes of exactly one can be a squeeze or unsqueeze.
    view_in_shape = first_permute.meta["val"].shape
    view_out_shape = view_node.meta["val"].shape
    if abs(len(view_in_shape) - len(view_out_shape)) != 1:
        return False

    # Start from identity dims on the tensor feeding the leading permute.
    source = get_arg(first_permute, "input", Node)
    dims = list(range(source.meta["val"].ndim))

    # Leading permute/transpose.
    if first_permute.target == exir_ops.edge.aten.transpose_copy.int:
        dims = get_transposed_dims(first_permute, dims)
    else:
        dims = get_permuted_dims(first_permute, dims)

    # Squeeze/unsqueeze in the middle.
    dims = self._apply_view_to_dims(dims, view_in_shape, view_out_shape)
    if dims is None:
        return False

    # Trailing permute/transpose (``node``).
    if node.target == exir_ops.edge.aten.transpose_copy.int:
        node_dims = get_transposed_dims(node, list(range(len(dims))))
        dims = [dims[d] for d in node_dims]
    elif node.target == exir_ops.edge.aten.permute_copy.default:
        dims = [dims[d] for d in get_arg(node, "dims")]
    else:
        raise ValueError(f"Unexpected target: { node .target } ")

    # Ignore the -1 marker of an unsqueezed dim; the remaining entries
    # must be in ascending order for the two permutes to cancel out.
    real_dims = [d for d in dims if d != -1]
    if real_dims != sorted(real_dims):
        return False

    # The output rank differs from the source rank by exactly one (checked
    # above, and permutes keep rank), so the result can never equal the
    # source shape: a view_copy to the output shape is always required.
    output_shape = node.meta["val"].shape
    with node.graph.inserting_before(node):
        new_view = node.graph.call_function(
            exir_ops.edge.aten.view_copy.default,
            args=(source, list(output_shape)),
        )
    new_view.meta = node.meta
    node.replace_all_uses_with(new_view)
    return True
171+
172+ @staticmethod
173+ def _find_extra_one (longer : list [int ], shorter : list [int ]) -> int :
174+ if len (longer ) != len (shorter ) + 1 :
175+ return - 1
176+ for i in range (len (shorter )):
177+ if longer [i ] != shorter [i ]:
178+ if longer [i ] == 1 and shorter [i :] == longer [i + 1 :]:
179+ return i
180+ return - 1
181+ return len (shorter ) if longer [- 1 ] == 1 else - 1
0 commit comments