Skip to content

Commit 67185b6

Browse files
authored
Merge branch 'main' into export-D105741356
2 parents 86e6284 + 6f052fe commit 67185b6

3 files changed

Lines changed: 358 additions & 111 deletions

File tree

backends/transforms/fuse_cascaded_transpose_or_permute_ops.py

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,40 @@
2020
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
2121
"""
2222
Fuse a chain of transpose and permute ops into a single permute or a no-op.
23-
Handles branches and chains permutes.
23+
Handles branches and chains of permutes, including permute-view-permute
24+
patterns where a squeeze/unsqueeze view sits between two permutes.
2425
"""
2526

2627
transpose_or_permute_target = {
2728
exir_ops.edge.aten.transpose_copy.int,
2829
exir_ops.edge.aten.permute_copy.default,
2930
}
3031

32+
_VIEW_OPS = {
33+
exir_ops.edge.aten.view_copy.default,
34+
exir_ops.edge.aten.view.default,
35+
}
36+
3137
@property
3238
def targets(self) -> list[EdgeOpOverload]:
3339
return list(self.transpose_or_permute_target)
3440

3541
def maybe_remove_or_replace(self, node: Node) -> bool:
36-
# Fuse with the parent node if it's also a permute or a transpose. Since the
37-
# pass interface traverses all ops in order the pass will properly fuse a chain
38-
# of permutes.
3942
parent_node = get_arg(node, "input", Node)
40-
if parent_node.target not in self.transpose_or_permute_target:
41-
return False
42-
input_of_parent = get_arg(parent_node, "input", Node)
4343

44-
# Compute combined effect of permutes.
44+
# Case 1: Direct permute/transpose → permute/transpose
45+
if parent_node.target in self.transpose_or_permute_target:
46+
return self._fuse_direct(node, parent_node)
47+
48+
# Case 2: permute → view_copy(squeeze/unsqueeze) → permute
49+
if parent_node.target in self._VIEW_OPS:
50+
return self._fuse_across_view(node, parent_node)
51+
52+
return False
53+
54+
def _fuse_direct(self, node: Node, parent_node: Node) -> bool:
55+
"""Fuse two adjacent permute/transpose ops."""
56+
input_of_parent = get_arg(parent_node, "input", Node)
4557
dims = list(range(node.meta["val"].ndim))
4658

4759
if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
@@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
5466
else:
5567
dims = get_permuted_dims(node, dims)
5668

57-
# If combined effect is identity replace the node with input.
5869
if dims == sorted(dims):
5970
node.replace_all_uses_with(input_of_parent)
6071
else:
@@ -67,3 +78,104 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
6778
node.replace_all_uses_with(new_permute)
6879

6980
return True
81+
82+
def _apply_view_to_dims(
83+
self, dims: list[int], view_in_shape, view_out_shape
84+
) -> list[int] | None:
85+
"""Apply a squeeze or unsqueeze view to dimension mapping.
86+
87+
Returns the updated dims, or None if the view cannot be mapped.
88+
"""
89+
if len(view_out_shape) == len(view_in_shape) + 1:
90+
# unsqueeze: insert a new dim
91+
index = self._find_extra_one(view_out_shape, view_in_shape)
92+
if index == -1:
93+
return None
94+
dims = [x + 1 if x >= index else x for x in dims]
95+
dims.insert(index, -1) # -1 marks the inserted dim
96+
elif len(view_in_shape) == len(view_out_shape) + 1:
97+
# squeeze: remove a dim
98+
index = self._find_extra_one(view_in_shape, view_out_shape)
99+
if index == -1:
100+
return None
101+
dims = list(dims)
102+
del dims[index]
103+
return dims
104+
105+
def _fuse_across_view(self, node: Node, view_node: Node) -> bool: # noqa: C901
106+
"""Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy."""
107+
# view_node must have exactly one user (this permute node)
108+
if len(view_node.users) != 1:
109+
return False
110+
# view_node's parent must be a permute/transpose
111+
view_input = get_arg(view_node, "input", Node)
112+
if view_input.target not in self.transpose_or_permute_target:
113+
return False
114+
# The view must be a squeeze or unsqueeze (rank differs by 1)
115+
view_in_shape = view_input.meta["val"].shape
116+
view_out_shape = view_node.meta["val"].shape
117+
if abs(len(view_in_shape) - len(view_out_shape)) != 1:
118+
return False
119+
120+
# Get the input before the first permute
121+
input_of_first_permute = get_arg(view_input, "input", Node)
122+
123+
# Compute the combined effect on the original input dimensions
124+
# Start with identity dims for the original input
125+
original_ndim = input_of_first_permute.meta["val"].ndim
126+
dims = list(range(original_ndim))
127+
128+
# Apply first permute
129+
if view_input.target == exir_ops.edge.aten.transpose_copy.int:
130+
dims = get_transposed_dims(view_input, dims)
131+
else:
132+
dims = get_permuted_dims(view_input, dims)
133+
134+
# Apply the view (squeeze/unsqueeze)
135+
dims = self._apply_view_to_dims(dims, view_in_shape, view_out_shape)
136+
if dims is None:
137+
return False
138+
139+
# Apply second permute (node)
140+
if node.target == exir_ops.edge.aten.transpose_copy.int:
141+
node_dims = list(range(len(dims)))
142+
node_dims = get_transposed_dims(node, node_dims)
143+
dims = [dims[d] for d in node_dims]
144+
elif node.target == exir_ops.edge.aten.permute_copy.default:
145+
perm = get_arg(node, "dims")
146+
dims = [dims[d] for d in perm]
147+
else:
148+
raise ValueError(f"Unexpected target: {node.target}")
149+
150+
# Check if the combined effect (ignoring -1 inserted dims) is identity
151+
real_dims = [d for d in dims if d != -1]
152+
153+
if real_dims == sorted(real_dims):
154+
# Combined permutations are identity — replace with view_copy
155+
# (the only remaining effect is the squeeze/unsqueeze reshape)
156+
output_shape = node.meta["val"].shape
157+
if output_shape == input_of_first_permute.meta["val"].shape:
158+
# Total no-op: replace with input
159+
node.replace_all_uses_with(input_of_first_permute)
160+
else:
161+
with node.graph.inserting_before(node):
162+
new_view = node.graph.call_function(
163+
exir_ops.edge.aten.view_copy.default,
164+
args=(input_of_first_permute, list(output_shape)),
165+
)
166+
new_view.meta = node.meta
167+
node.replace_all_uses_with(new_view)
168+
return True
169+
170+
return False
171+
172+
@staticmethod
173+
def _find_extra_one(longer: list[int], shorter: list[int]) -> int:
174+
if len(longer) != len(shorter) + 1:
175+
return -1
176+
for i in range(len(shorter)):
177+
if longer[i] != shorter[i]:
178+
if longer[i] == 1 and shorter[i:] == longer[i + 1 :]:
179+
return i
180+
return -1
181+
return len(shorter) if longer[-1] == 1 else -1

backends/transforms/remove_permutes_around_elementwise_ops.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414
import torch.fx
15-
from executorch.backends.transforms.permute_pass_utils import get_arg
15+
from executorch.backends.transforms.permute_pass_utils import get_arg, set_arg
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717
from executorch.exir.pass_base import ExportPass, PassResult
1818

@@ -106,11 +106,8 @@ def _check_squeeze_unsqueeze_view(self, node: torch.fx.Node) -> bool:
106106
return True
107107
if node.target not in self._VIEW_OPS:
108108
return False
109-
if node.meta.get("val") is None:
110-
return False
111109
inp = node.args[0]
112-
if not isinstance(inp, torch.fx.Node) or inp.meta.get("val") is None:
113-
return False
110+
assert isinstance(inp, torch.fx.Node)
114111
in_shape = inp.meta["val"].shape
115112
out_shape = node.meta["val"].shape
116113
if len(out_shape) == len(in_shape) + 1:
@@ -377,9 +374,9 @@ def _is_constant(self, node: torch.fx.Node) -> bool:
377374
def _get_node_rank(self, node: torch.fx.Node) -> int | None:
378375
"""Return the tensor rank of a node's output, or None if unknown."""
379376
val = node.meta.get("val")
380-
if val is not None and hasattr(val, "shape"):
381-
return len(val.shape)
382-
return None
377+
if val is None:
378+
return None
379+
return len(val.shape)
383380

384381
@staticmethod
385382
def _is_pointwise(target) -> bool:
@@ -432,10 +429,8 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: # noqa: C901
432429
node.update_arg(1, index)
433430
elif node.target in self._SQUEEZE_OPS:
434431
# squeeze dim is in input space (rank)
435-
dim = cast(int, node.args[1])
436-
rank = len(node_start_perm)
437-
index = dim if dim >= 0 else dim + rank
438-
node.update_arg(1, node_start_perm[index])
432+
dim = get_arg(node, "dim", int)
433+
set_arg(node, "dim", node_start_perm[dim])
439434

440435
# Skip incoming permutes.
441436
for inp, out in subgraph.edges_in:
@@ -486,42 +481,25 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: # noqa: C901
486481
out.replace_all_uses_with(inp)
487482

488483
def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
489-
if len(node.args) >= 2:
490-
node.update_arg(1, start_permute[cast(int, node.args[1])])
491-
elif "dim" in node.kwargs:
492-
node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
493-
else:
494-
# Default cat dim is 0.
495-
node.update_kwarg("dim", start_permute[0])
484+
dim = get_arg(node, "dim", int)
485+
set_arg(node, "dim", start_permute[dim])
496486

497487
def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
498-
if len(node.args) >= 2:
499-
node.update_arg(
500-
1, [start_permute[dim] for dim in cast(list[int], node.args[1])]
501-
)
502-
else:
503-
node.update_kwarg(
504-
"dim",
505-
[start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])],
506-
)
488+
dims = get_arg(node, "dim")
489+
set_arg(node, "dim", [start_permute[d] for d in cast(list[int], dims)])
507490

508491
def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
509-
if len(node.args) >= 2:
510-
node.update_arg(1, start_permute[cast(int, node.args[1])])
511-
else:
512-
node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
492+
dim = get_arg(node, "dim", int)
493+
set_arg(node, "dim", start_permute[dim])
513494

514495
def update_view_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
515496
"""Adjust view_copy shape arg after permute removal.
516497
517498
After removing the start permute, the view's input is in the original
518499
(un-permuted) layout. Recompute the view's target shape accordingly.
519500
"""
520-
if node.meta.get("val") is None:
521-
return
522501
inp = node.args[0]
523-
if not isinstance(inp, torch.fx.Node) or inp.meta.get("val") is None:
524-
return
502+
assert isinstance(inp, torch.fx.Node)
525503

526504
in_shape = inp.meta["val"].shape
527505
out_shape = node.meta["val"].shape

0 commit comments

Comments
 (0)