Skip to content

Commit b36f218

Browse files
DrJessopfacebook-github-bot
authored andcommitted
Reorder slice before view (#20240)
Summary: The closer a slice is to the compute that produces its input, the easier it is to perform certain optimizations that were not possible previously. We have seen cases with linear -> view -> slice nodes; if those slice nodes came right after the linear, we could slice out the channel dim directly in the linear's weights at compile time rather than hitting the runtime performance penalty of a non-contiguous slice. Reviewed By: abeakkas Differential Revision: D108217652
1 parent 185bd09 commit b36f218

2 files changed

Lines changed: 447 additions & 1 deletion

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 234 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from collections import defaultdict
1313
from math import prod
14-
from typing import Callable, cast, DefaultDict, List, Tuple
14+
from typing import Callable, cast, DefaultDict, List, Optional, Tuple
1515

1616
import torch
1717
import torch.fx
@@ -781,6 +781,239 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
781781
return True
782782

783783

784+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class MoveSliceBeforeViewPass(RemoveOrReplacePassInterface):
    """Move a slice_copy above a view_copy when the slice is re-expressible as a
    single slice on one dim of the pre-view tensor.

    Rewrites view(x) -> slice(dim=d, start, end, step) into
    slice(x, dim=d', start', end', step') -> view(sliced, slice_out_shape), so the
    slice lands directly on x. This may be useful in attention patterns, where
    we view outputs of a large linear into a new shape where the number of
    attention heads are the last dim, and we need to run independent computation
    per head. Moving the slice before the view can allow us to then directly slice
    the constant linear weights.

    A view is a contiguous reshape: it never moves or reorders elements, it only
    re-groups the shared row-major index space into different dims. A slice keeps
    an arithmetic progression of indices (start, start+step, ...) along one viewed
    dim, and that progression collapses back to a *single* slice on one pre-view
    dim exactly when the row-major strides line up. ``_derive_pre_view_slice``
    handles the three cases that qualify:

      * untouched: the viewed dim is left unchanged by the view -- same size
        and same inner stride as some pre-view dim -- so the slice copies over
        verbatim (any step).
      * contiguous: the viewed dim and a pre-view dim span the same flat extent
        (a split's outermost factor, or a merge that aligns), so a contiguous
        selection (step == 1, or a single selected index with any step) maps
        to a contiguous pre-view slice.
      * strided: the viewed dim is an innermost factor of a pre-view dim
        (identical inner stride) selected width-1, so it maps to a strided
        pre-view slice with step == the viewed dim's size.

    Everything else -- middle factors, wider strided selections -- is block-strided
    (runs separated by gaps), which no single slice can express, so it is left
    unchanged.

    Each slice is handled independently, so a view that fans out to several slices
    is rewritten one slice at a time and the now-dead view is removed by dead-code
    elimination -- there is no single-user requirement on the view.

    Only fully static graphs are rewritten: symbolic slice bounds and symbolic
    shape extents both cause the node to be left unchanged.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Edge ops this pass is invoked on: ``slice_copy.Tensor`` only."""
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Rewrite ``view_copy -> slice_copy`` into ``slice_copy -> view_copy``.

        Args:
            node: A ``slice_copy.Tensor`` node.

        Returns:
            True if all uses of ``node`` were redirected to a newly created
            ``slice_copy -> view_copy`` pair; False if the graph was left
            unchanged (input is not a view_copy, a shape or bound is symbolic,
            a tensor is empty, or no single pre-view slice is equivalent).
        """
        view_node = get_arg(node, "input", torch.fx.Node)
        if view_node.target != exir_ops.edge.aten.view_copy.default:
            return False

        x_node = get_arg(view_node, "input", torch.fx.Node)
        pre_view_shape = tuple(x_node.meta["val"].shape)
        post_view_shape = tuple(view_node.meta["val"].shape)
        # The stride arithmetic below needs concrete ints; symbolic extents are
        # rejected for the same reason _normalize_slice rejects symbolic bounds.
        if not self._is_static_shape(pre_view_shape + post_view_shape):
            return False
        if 0 in pre_view_shape or 0 in post_view_shape:
            return False

        dim = get_arg(node, "dim", int)
        if dim < 0:
            dim += len(post_view_shape)
        post_view_size = post_view_shape[dim]

        bounds = self._normalize_slice(node, post_view_size)
        if bounds is None:
            return False
        start, stop, step = bounds

        # The slice's own output shape gives the selected-element count along the
        # sliced dim directly -- it is exactly output_shape[dim].
        slice_out_shape = tuple(node.meta["val"].shape)
        if not self._is_static_shape(slice_out_shape):
            return False
        post_view_count = slice_out_shape[dim]
        if post_view_count == 0:
            return False

        # Row-major stride of the sliced viewed dim, and of every pre-view dim.
        post_view_stride = prod(post_view_shape[dim + 1 :])
        pre_view_strides = self._row_major_strides(pre_view_shape)

        derived = self._derive_pre_view_slice(
            pre_view_shape,
            pre_view_strides,
            post_view_stride,
            post_view_size,
            start,
            stop,
            step,
            post_view_count,
        )
        if derived is None:
            return False
        pre_view_dim, pre_view_start, pre_view_stop, pre_view_step = derived

        graph = node.graph
        with graph.inserting_before(node):
            new_slice_args = (
                x_node,
                pre_view_dim,
                pre_view_start,
                pre_view_stop,
                pre_view_step,
            )
            new_slice = graph.create_node(
                "call_function",
                exir_ops.edge.aten.slice_copy.Tensor,
                args=new_slice_args,
            )
            # Recompute the meta value by running the op on the input's meta
            # value, so the new node carries its own output shape.
            new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
                x_node.meta["val"], *new_slice_args[1:]
            )
            # View the sliced tensor into the original slice's output shape so
            # downstream users see an unchanged shape.
            new_view = graph.create_node(
                "call_function",
                exir_ops.edge.aten.view_copy.default,
                args=(new_slice, list(slice_out_shape)),
            )
            new_view.meta["val"] = exir_ops.edge.aten.view_copy.default(
                new_slice.meta["val"], list(slice_out_shape)
            )

        node.replace_all_uses_with(new_view)
        return True

    @staticmethod
    def _is_static_shape(shape: tuple) -> bool:
        """Return True when every extent in ``shape`` is a plain Python int."""
        return all(isinstance(extent, int) for extent in shape)

    @staticmethod
    def _row_major_strides(shape: tuple[int, ...]) -> list[int]:
        """Row-major (contiguous) strides for ``shape``."""
        strides = [1] * len(shape)
        acc = 1
        # Walk from the innermost dim outwards, accumulating the product of the
        # extents already visited.
        for i in range(len(shape) - 1, -1, -1):
            strides[i] = acc
            acc *= shape[i]
        return strides

    def _normalize_slice(
        self, node: torch.fx.Node, post_view_size: int
    ) -> Optional[tuple[int, int, int]]:
        """Resolve the slice to concrete, clamped ``(start, stop, step)`` ints, or
        None if the bounds are dynamic or the step is non-positive (neither of
        which this pass handles).

        Args:
            node: The ``slice_copy`` node whose ``start``/``end``/``step``
                arguments are read.
            post_view_size: Extent of the sliced dim; used as the default stop,
                to wrap negative bounds, and as the clamp limit.
        """
        step = get_arg(node, "step")

        if not isinstance(step, int):
            return None

        if step <= 0:
            return None

        raw_start = get_arg(node, "start")
        raw_stop = get_arg(node, "end")

        # Make sure raw_start/raw_stop are not symbolic.
        if (raw_start is not None and not isinstance(raw_start, int)) or (
            raw_stop is not None and not isinstance(raw_stop, int)
        ):
            return None

        start = 0 if raw_start is None else raw_start
        stop = post_view_size if raw_stop is None else raw_stop
        # Negative bounds count from the end of the dim.
        if start < 0:
            start += post_view_size
        if stop < 0:
            stop += post_view_size
        # Clamp both bounds into [0, post_view_size].
        start = max(0, min(start, post_view_size))
        stop = max(0, min(stop, post_view_size))
        return start, stop, step

    def _derive_pre_view_slice(
        self,
        pre_view_shape: tuple[int, ...],
        pre_view_strides: list[int],
        post_view_stride: int,
        post_view_size: int,
        start: int,
        stop: int,
        step: int,
        post_view_count: int,
    ) -> Optional[tuple[int, int, int, int]]:
        """Return ``(dim, start, stop, step)`` for the single pre-view-tensor slice
        equivalent to slicing the viewed dim, or None if no single pre-view slice
        reproduces it.

        Both shapes index the same row-major flat space, so the sliced viewed dim
        (size ``post_view_size``, inner stride ``post_view_stride``) lines up with
        one pre-view dim (size ``pre_view_size``, inner stride ``pre_view_stride``)
        in one of three ways. Pre-view dims are tried in order and the first
        match is returned.

        Args:
            pre_view_shape: Shape of the tensor feeding the view. Expected to
                contain no zero extents (the caller filters those out).
            pre_view_strides: Row-major strides matching ``pre_view_shape``.
            post_view_stride: Row-major stride of the sliced viewed dim.
            post_view_size: Extent of the sliced viewed dim.
            start: Normalized, clamped slice start on the viewed dim.
            stop: Normalized, clamped slice stop on the viewed dim.
            step: Positive slice step on the viewed dim.
            post_view_count: Number of indices the slice selects (non-zero).
        """
        # A width-1 selection picks exactly index ``start`` whatever the step
        # is, so it is as contiguous as a step==1 slice.
        selects_one_run = step == 1 or post_view_count == 1
        # Loop-invariant quantities of the viewed dim: its flat extent (period)
        # and the flat band [flat_start, flat_stop) covered by a contiguous
        # selection [start, start + post_view_count).
        post_view_extent = post_view_size * post_view_stride
        flat_start = start * post_view_stride
        flat_stop = (start + post_view_count) * post_view_stride

        for pre_view_dim, (pre_view_stride, pre_view_size) in enumerate(
            zip(pre_view_strides, pre_view_shape)
        ):
            # Untouched: the viewed dim is identical to this pre-view dim (same
            # size and same inner stride), so the slice applies verbatim, any step.
            if pre_view_stride == post_view_stride and pre_view_size == post_view_size:
                return pre_view_dim, start, stop, step

            # Contiguous: the viewed dim and this pre-view dim span the same flat
            # extent (same period), and the selected band aligns to this dim's
            # boundaries. The band is a contiguous slice on this pre-view dim iff
            # both of its ends are multiples of this dim's stride.
            if (
                selects_one_run
                and post_view_extent == pre_view_size * pre_view_stride
                and flat_start % pre_view_stride == 0
                and flat_stop % pre_view_stride == 0
            ):
                return (
                    pre_view_dim,
                    flat_start // pre_view_stride,
                    flat_stop // pre_view_stride,
                    1,
                )

            # Strided is the ONLY way the reshape itself introduces a stride, and
            # it requires a width-1 selection (post_view_count == 1): the viewed
            # dim is an innermost factor of this pre-view dim (identical inner
            # stride), so fixing that single factor index and letting the rest of
            # the pre-view dim run yields a uniform stride equal to the viewed dim's
            # size. Any wider selection (post_view_count > 1) of an inner factor
            # leaves runs separated by gaps -- block-strided, not a single slice --
            # so width-1 is required.
            if (
                post_view_count == 1
                and post_view_size > 1
                and pre_view_stride == post_view_stride
                and pre_view_size % post_view_size == 0
            ):
                pre_view_count = pre_view_size // post_view_size
                # Last selected index is start + (count - 1) * step; stop is one
                # past it.
                pre_view_stop = start + (pre_view_count - 1) * post_view_size + 1
                return pre_view_dim, start, pre_view_stop, post_view_size
        return None
1015+
1016+
7841017
@register_cadence_pass(CadencePassAttribute(opt_level=1))
7851018
class PropagateSlice(RemoveOrReplacePassInterface):
7861019
"""Propagate slice_copy before element-wise ops when the cost model

0 commit comments

Comments
 (0)