Skip to content

Commit 24dea61

Browse files
DrJessopfacebook-github-bot
authored andcommitted
Reorder slice before view
Summary: The closer slice is to compute, the easier it is to perform certain optimizations you couldn't previously. Have seen cases where we have linear -> view -> slice nodes, and if those slice nodes were right after the linear, we could have sliced out the channel dim directly in those weights at compile time rather than hitting runtime non-contiguous slice performance penalties. Differential Revision: D108217652
1 parent d7ca5db commit 24dea61

2 files changed

Lines changed: 450 additions & 1 deletion

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 237 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from collections import defaultdict
1313
from math import prod
14-
from typing import Callable, cast, DefaultDict, List, Tuple
14+
from typing import Callable, cast, DefaultDict, List, Optional, Tuple
1515

1616
import torch
1717
import torch.fx
@@ -781,6 +781,242 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
781781
return True
782782

783783

784+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
785+
class MoveSliceBeforeViewPass(RemoveOrReplacePassInterface):
786+
"""Move a slice_copy above a view_copy when the slice is re-expressible as a
787+
single slice on one dim of the pre-view tensor.
788+
789+
Rewrites view(x) -> slice(dim=d, start, end, step) into
790+
slice(x, dim=d', start', end', step') -> view(sliced, slice_out_shape), so the
791+
slice lands directly on x. This may be useful in attention patterns, where
792+
we view outputs of a large linear into a new shape where the number of
793+
attention heads are the last dim, and we need to run independent computation
794+
per head. Moving the slice before the view can allow us to then directly slice
795+
the constant linear weights.
796+
797+
A view is a contiguous reshape: it never moves or reorders elements, it only
798+
re-groups the shared row-major index space into different dims. A slice keeps
799+
an arithmetic progression of indices (start, start+step, ...) along one viewed
800+
dim, and that progression collapses back to a *single* slice on one pre-view
801+
dim exactly when the row-major strides line up. ``_derive_pre_view_slice``
802+
handles the three cases that qualify:
803+
804+
* untouched dim: the viewed dim is left unchanged by the view -- same size
805+
and same inner stride as some pre-view dim -- so the slice copies over
806+
verbatim (any step).
807+
* contiguous: the viewed dim and a pre-view dim span the same flat extent
808+
(a split's outermost factor, or a merge that aligns), so a contiguous
809+
(step==1) slice maps to a contiguous pre-view slice.
810+
* strided: the viewed dim is an innermost factor of a pre-view dim
811+
(identical inner stride) selected width-1, so it maps to a strided
812+
pre-view slice with step == the viewed dim's size.
813+
814+
Everything else -- middle factors, wider strided selections -- is block-strided
815+
(runs separated by gaps), which no single slice can express, so it is left
816+
unchanged.
817+
818+
Each slice is handled independently, so a view that fans out to several slices
819+
is rewritten one slice at a time and the now-dead view is removed by dead-code
820+
elimination -- there is no single-user requirement on the view.
821+
"""
822+
823+
@property
824+
def targets(self) -> list[EdgeOpOverload]:
825+
return [exir_ops.edge.aten.slice_copy.Tensor]
826+
827+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
828+
view_node = get_arg(node, "input", torch.fx.Node)
829+
if view_node.target != exir_ops.edge.aten.view_copy.default:
830+
return False
831+
832+
x_node = get_arg(view_node, "input", torch.fx.Node)
833+
pre_view_shape = tuple(x_node.meta["val"].shape)
834+
post_view_shape = tuple(view_node.meta["val"].shape)
835+
# The rewrite is pure index arithmetic; symbolic or empty shapes are out.
836+
if not all(isinstance(d, int) for d in pre_view_shape + post_view_shape):
837+
return False
838+
if 0 in pre_view_shape or 0 in post_view_shape:
839+
return False
840+
841+
dim = get_arg(node, "dim", int)
842+
if dim < 0:
843+
dim += len(post_view_shape)
844+
post_view_size = post_view_shape[dim]
845+
846+
bounds = self._normalize_slice(node, post_view_size)
847+
if bounds is None:
848+
return False
849+
start, stop, step, post_view_count = bounds
850+
851+
# Row-major stride of the sliced viewed dim, and of every pre-view dim.
852+
post_view_stride = prod(post_view_shape[dim + 1 :])
853+
pre_view_strides = self._row_major_strides(pre_view_shape)
854+
855+
derived = self._derive_pre_view_slice(
856+
pre_view_shape,
857+
pre_view_strides,
858+
post_view_stride,
859+
post_view_size,
860+
start,
861+
stop,
862+
step,
863+
post_view_count,
864+
)
865+
if derived is None:
866+
return False
867+
pre_view_dim, pre_view_start, pre_view_stop, pre_view_step = derived
868+
869+
slice_out_shape = tuple(node.meta["val"].shape)
870+
graph = node.graph
871+
with graph.inserting_before(node):
872+
new_slice_args = (
873+
x_node,
874+
pre_view_dim,
875+
pre_view_start,
876+
pre_view_stop,
877+
pre_view_step,
878+
)
879+
new_slice = graph.create_node(
880+
"call_function",
881+
exir_ops.edge.aten.slice_copy.Tensor,
882+
args=new_slice_args,
883+
)
884+
new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
885+
x_node.meta["val"], *new_slice_args[1:]
886+
)
887+
new_view = graph.create_node(
888+
"call_function",
889+
exir_ops.edge.aten.view_copy.default,
890+
args=(new_slice, list(slice_out_shape)),
891+
)
892+
new_view.meta["val"] = exir_ops.edge.aten.view_copy.default(
893+
new_slice.meta["val"], list(slice_out_shape)
894+
)
895+
896+
node.replace_all_uses_with(new_view)
897+
return True
898+
899+
@staticmethod
900+
def _row_major_strides(shape: tuple[int, ...]) -> list[int]:
901+
"""Row-major (contiguous) strides for ``shape``."""
902+
strides = [1] * len(shape)
903+
acc = 1
904+
for i in range(len(shape) - 1, -1, -1):
905+
strides[i] = acc
906+
acc *= shape[i]
907+
return strides
908+
909+
def _normalize_slice(
910+
self, node: torch.fx.Node, post_view_size: int
911+
) -> Optional[tuple[int, int, int, int]]:
912+
"""Resolve the slice to concrete, clamped ``(start, stop, step,
913+
post_view_count)`` ints, or None if the bounds are dynamic, the selection
914+
is empty, or the step is non-positive (none of which this pass handles)."""
915+
step = get_arg(node, "step")
916+
917+
if not isinstance(step, int):
918+
return None
919+
920+
if step <= 0:
921+
return None
922+
923+
raw_start = get_arg(node, "start")
924+
raw_stop = get_arg(node, "end")
925+
926+
# Make sure raw_start/raw_stop are not symbolic.
927+
if (raw_start is not None and not isinstance(raw_start, int)) or (
928+
raw_stop is not None and not isinstance(raw_stop, int)
929+
):
930+
return None
931+
932+
start = 0 if raw_start is None else raw_start
933+
stop = post_view_size if raw_stop is None else raw_stop
934+
if start < 0:
935+
start += post_view_size
936+
if stop < 0:
937+
stop += post_view_size
938+
start = max(0, min(start, post_view_size))
939+
stop = max(0, min(stop, post_view_size))
940+
if stop <= start:
941+
return None
942+
943+
post_view_count = (stop - start + step - 1) // step
944+
if post_view_count <= 0:
945+
return None
946+
return start, stop, step, post_view_count
947+
948+
def _derive_pre_view_slice(
949+
self,
950+
pre_view_shape: tuple[int, ...],
951+
pre_view_strides: list[int],
952+
post_view_stride: int,
953+
post_view_size: int,
954+
start: int,
955+
stop: int,
956+
step: int,
957+
post_view_count: int,
958+
) -> tuple[int, int, int, int] | None:
959+
"""Return ``(dim, start, stop, step)`` for the single pre-view-tensor slice
960+
equivalent to slicing the viewed dim, or None if no single pre-view slice
961+
reproduces it.
962+
963+
Both shapes index the same row-major flat space, so the sliced viewed dim
964+
(size ``post_view_size``, inner stride ``post_view_stride``) lines up with
965+
one pre-view dim (size ``pre_view_size``, inner stride ``pre_view_stride``)
966+
in one of three ways.
967+
"""
968+
for pre_view_dim, (pre_view_stride, pre_view_size) in enumerate(
969+
zip(pre_view_strides, pre_view_shape)
970+
):
971+
# Untouched: the viewed dim is identical to this pre-view dim (same
972+
# size and same inner stride), so the slice applies verbatim, any step.
973+
if pre_view_stride == post_view_stride and pre_view_size == post_view_size:
974+
return pre_view_dim, start, stop, step
975+
976+
# Contiguous: the viewed dim and this pre-view dim span the same flat
977+
# extent (same period), and the selected band aligns to this dim's
978+
# boundaries. A contiguous (step==1) viewed slice
979+
# [start, start+post_view_count) is the flat band [start*
980+
# post_view_stride, (start+post_view_count)*post_view_stride), a
981+
# contiguous slice on this pre-view dim iff both ends are multiples of
982+
# its stride.
983+
if (
984+
step == 1
985+
and post_view_size * post_view_stride == pre_view_size * pre_view_stride
986+
):
987+
flat_start = start * post_view_stride
988+
flat_stop = (start + post_view_count) * post_view_stride
989+
if (
990+
flat_start % pre_view_stride == 0
991+
and flat_stop % pre_view_stride == 0
992+
):
993+
return (
994+
pre_view_dim,
995+
flat_start // pre_view_stride,
996+
flat_stop // pre_view_stride,
997+
1,
998+
)
999+
1000+
# Strided is the ONLY way the reshape itself introduces a stride, and
1001+
# it requires a width-1 selection (post_view_count == 1): the viewed
1002+
# dim is an innermost factor of this pre-view dim (identical inner
1003+
# stride), so fixing that single factor index and letting the rest of
1004+
# the pre-view dim run yields a uniform stride equal to the viewed dim's
1005+
# size. Any wider selection (post_view_count > 1) of an inner factor
1006+
# leaves runs separated by gaps -- block-strided, not a single slice --
1007+
# so width-1 is required.
1008+
if (
1009+
post_view_count == 1
1010+
and post_view_size > 1
1011+
and pre_view_stride == post_view_stride
1012+
and pre_view_size % post_view_size == 0
1013+
):
1014+
pre_view_count = pre_view_size // post_view_size
1015+
pre_view_stop = start + (pre_view_count - 1) * post_view_size + 1
1016+
return pre_view_dim, start, pre_view_stop, post_view_size
1017+
return None
1018+
1019+
7841020
@register_cadence_pass(CadencePassAttribute(opt_level=1))
7851021
class PropagateSlice(RemoveOrReplacePassInterface):
7861022
"""Propagate slice_copy before element-wise ops when the cost model

0 commit comments

Comments
 (0)