|
11 | 11 |
|
12 | 12 | from collections import defaultdict |
13 | 13 | from math import prod |
14 | | -from typing import Callable, cast, DefaultDict, List, Tuple |
| 14 | +from typing import Callable, cast, DefaultDict, List, Optional, Tuple |
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | import torch.fx |
@@ -781,6 +781,242 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
781 | 781 | return True |
782 | 782 |
|
783 | 783 |
|
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class MoveSliceBeforeViewPass(RemoveOrReplacePassInterface):
    """Move a slice_copy above a view_copy when the slice is re-expressible as a
    single slice on one dim of the pre-view tensor.

    Rewrites view(x) -> slice(dim=d, start, end, step) into
    slice(x, dim=d', start', end', step') -> view(sliced, slice_out_shape), so the
    slice lands directly on x. This may be useful in attention patterns, where
    we view outputs of a large linear into a new shape where the number of
    attention heads is the last dim, and we need to run independent computation
    per head. Moving the slice before the view can allow us to then directly slice
    the constant linear weights.

    A view is a contiguous reshape: it never moves or reorders elements, it only
    re-groups the shared row-major index space into different dims. A slice keeps
    an arithmetic progression of indices (start, start+step, ...) along one viewed
    dim, and that progression collapses back to a *single* slice on one pre-view
    dim exactly when the row-major strides line up. ``_derive_pre_view_slice``
    handles the three cases that qualify:

      * untouched dim: the viewed dim is left unchanged by the view -- same size
        and same inner stride as some pre-view dim -- so the slice copies over
        verbatim (any step).
      * contiguous: the viewed dim and a pre-view dim span the same flat extent
        (a split's outermost factor, or a merge that aligns), so a contiguous
        (step==1) slice maps to a contiguous pre-view slice.
      * strided: the viewed dim is an innermost factor of a pre-view dim
        (identical inner stride) selected width-1, so it maps to a strided
        pre-view slice with step == the viewed dim's size.

    Everything else -- middle factors, wider strided selections -- is block-strided
    (runs separated by gaps), which no single slice can express, so it is left
    unchanged.

    Each slice is handled independently, so a view that fans out to several slices
    is rewritten one slice at a time and the now-dead view is removed by dead-code
    elimination -- there is no single-user requirement on the view.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Ops this pass is invoked on: only ``slice_copy.Tensor`` nodes."""
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Try to hoist the slice ``node`` above the view_copy that feeds it.

        Args:
            node: a ``slice_copy.Tensor`` node (see ``targets``).

        Returns:
            True if a new ``slice_copy -> view_copy`` pair was inserted and all
            uses of ``node`` were redirected to it; False if the input is not a
            view_copy, a shape or slice bound is not a concrete int, a shape
            contains a zero, the selection is empty, or no single pre-view slice
            reproduces the selection.
        """
        view_node = get_arg(node, "input", torch.fx.Node)
        if view_node.target != exir_ops.edge.aten.view_copy.default:
            return False

        x_node = get_arg(view_node, "input", torch.fx.Node)
        pre_view_shape = tuple(x_node.meta["val"].shape)
        post_view_shape = tuple(view_node.meta["val"].shape)
        # The rewrite is pure index arithmetic; symbolic or empty shapes are out.
        if not all(isinstance(d, int) for d in pre_view_shape + post_view_shape):
            return False
        if 0 in pre_view_shape or 0 in post_view_shape:
            return False

        # Normalize a negative dim against the viewed (post-view) rank.
        dim = get_arg(node, "dim", int)
        if dim < 0:
            dim += len(post_view_shape)
        post_view_size = post_view_shape[dim]

        bounds = self._normalize_slice(node, post_view_size)
        if bounds is None:
            return False
        start, stop, step, post_view_count = bounds

        # Row-major stride of the sliced viewed dim, and of every pre-view dim.
        post_view_stride = prod(post_view_shape[dim + 1 :])
        pre_view_strides = self._row_major_strides(pre_view_shape)

        derived = self._derive_pre_view_slice(
            pre_view_shape,
            pre_view_strides,
            post_view_stride,
            post_view_size,
            start,
            stop,
            step,
            post_view_count,
        )
        if derived is None:
            return False
        pre_view_dim, pre_view_start, pre_view_stop, pre_view_step = derived

        # The final result keeps the original slice's output shape: the new view
        # reshapes the pre-view slice back into it.
        slice_out_shape = tuple(node.meta["val"].shape)
        graph = node.graph
        with graph.inserting_before(node):
            new_slice_args = (
                x_node,
                pre_view_dim,
                pre_view_start,
                pre_view_stop,
                pre_view_step,
            )
            new_slice = graph.create_node(
                "call_function",
                exir_ops.edge.aten.slice_copy.Tensor,
                args=new_slice_args,
            )
            # Populate meta["val"] by running the op on the producer's meta value,
            # so the new nodes carry shape information like the ones they replace.
            new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
                x_node.meta["val"], *new_slice_args[1:]
            )
            new_view = graph.create_node(
                "call_function",
                exir_ops.edge.aten.view_copy.default,
                args=(new_slice, list(slice_out_shape)),
            )
            new_view.meta["val"] = exir_ops.edge.aten.view_copy.default(
                new_slice.meta["val"], list(slice_out_shape)
            )

        # The original slice is left without users after this call.
        node.replace_all_uses_with(new_view)
        return True

    @staticmethod
    def _row_major_strides(shape: tuple[int, ...]) -> list[int]:
        """Row-major (contiguous) strides for ``shape``.

        The last dim has stride 1 and each earlier dim's stride is the product
        of all sizes after it; an empty shape yields an empty list.
        """
        strides = [1] * len(shape)
        acc = 1
        for i in range(len(shape) - 1, -1, -1):
            strides[i] = acc
            acc *= shape[i]
        return strides

    def _normalize_slice(
        self, node: torch.fx.Node, post_view_size: int
    ) -> Optional[tuple[int, int, int, int]]:
        """Resolve the slice to concrete, clamped ``(start, stop, step,
        post_view_count)`` ints, or None if the bounds are dynamic, the selection
        is empty, or the step is non-positive (none of which this pass handles).

        Args:
            node: the slice_copy node whose ``start``/``end``/``step`` are read.
            post_view_size: size of the viewed dim being sliced; used to wrap
                negative bounds and to clamp both bounds into ``[0, size]``.

        Returns:
            ``(start, stop, step, post_view_count)`` where ``post_view_count`` is
            the number of selected indices (always >= 1), or None.
        """
        step = get_arg(node, "step")

        if not isinstance(step, int):
            return None

        if step <= 0:
            return None

        raw_start = get_arg(node, "start")
        raw_stop = get_arg(node, "end")

        # Make sure raw_start/raw_stop are not symbolic.
        if (raw_start is not None and not isinstance(raw_start, int)) or (
            raw_stop is not None and not isinstance(raw_stop, int)
        ):
            return None

        # A missing bound means "from the beginning" / "to the end" of the dim.
        start = 0 if raw_start is None else raw_start
        stop = post_view_size if raw_stop is None else raw_stop
        if start < 0:
            start += post_view_size
        if stop < 0:
            stop += post_view_size
        start = max(0, min(start, post_view_size))
        stop = max(0, min(stop, post_view_size))
        if stop <= start:
            return None

        # Ceil-division. With stop > start and step >= 1 this is always >= 1, so
        # no further emptiness check is needed.
        post_view_count = (stop - start + step - 1) // step
        return start, stop, step, post_view_count

    def _derive_pre_view_slice(
        self,
        pre_view_shape: tuple[int, ...],
        pre_view_strides: list[int],
        post_view_stride: int,
        post_view_size: int,
        start: int,
        stop: int,
        step: int,
        post_view_count: int,
    ) -> Optional[tuple[int, int, int, int]]:
        """Return ``(dim, start, stop, step)`` for the single pre-view-tensor slice
        equivalent to slicing the viewed dim, or None if no single pre-view slice
        reproduces it.

        Both shapes index the same row-major flat space, so the sliced viewed dim
        (size ``post_view_size``, inner stride ``post_view_stride``) lines up with
        one pre-view dim (size ``pre_view_size``, inner stride ``pre_view_stride``)
        in one of three ways.

        Pre-view dims are scanned from outermost to innermost and the first dim
        matching any of the three cases is returned.

        Args:
            pre_view_shape: sizes of the tensor before the view.
            pre_view_strides: row-major strides matching ``pre_view_shape``.
            post_view_stride: row-major stride of the sliced viewed dim.
            post_view_size: size of the sliced viewed dim.
            start: normalized slice start on the viewed dim.
            stop: normalized slice stop on the viewed dim.
            step: slice step on the viewed dim (positive).
            post_view_count: number of indices the viewed slice selects.
        """
        for pre_view_dim, (pre_view_stride, pre_view_size) in enumerate(
            zip(pre_view_strides, pre_view_shape)
        ):
            # Untouched: the viewed dim is identical to this pre-view dim (same
            # size and same inner stride), so the slice applies verbatim, any step.
            if pre_view_stride == post_view_stride and pre_view_size == post_view_size:
                return pre_view_dim, start, stop, step

            # Contiguous: the viewed dim and this pre-view dim span the same flat
            # extent (same period), and the selected band aligns to this dim's
            # boundaries. A contiguous (step==1) viewed slice
            # [start, start+post_view_count) is the flat band [start*
            # post_view_stride, (start+post_view_count)*post_view_stride), a
            # contiguous slice on this pre-view dim iff both ends are multiples of
            # its stride.
            if (
                step == 1
                and post_view_size * post_view_stride == pre_view_size * pre_view_stride
            ):
                flat_start = start * post_view_stride
                flat_stop = (start + post_view_count) * post_view_stride
                if (
                    flat_start % pre_view_stride == 0
                    and flat_stop % pre_view_stride == 0
                ):
                    return (
                        pre_view_dim,
                        flat_start // pre_view_stride,
                        flat_stop // pre_view_stride,
                        1,
                    )

            # Strided is the ONLY way the reshape itself introduces a stride, and
            # it requires a width-1 selection (post_view_count == 1): the viewed
            # dim is an innermost factor of this pre-view dim (identical inner
            # stride), so fixing that single factor index and letting the rest of
            # the pre-view dim run yields a uniform stride equal to the viewed dim's
            # size. Any wider selection (post_view_count > 1) of an inner factor
            # leaves runs separated by gaps -- block-strided, not a single slice --
            # so width-1 is required.
            if (
                post_view_count == 1
                and post_view_size > 1
                and pre_view_stride == post_view_stride
                and pre_view_size % post_view_size == 0
            ):
                pre_view_count = pre_view_size // post_view_size
                # One past the last selected index, so exactly pre_view_count
                # elements are taken at step post_view_size.
                pre_view_stop = start + (pre_view_count - 1) * post_view_size + 1
                return pre_view_dim, start, pre_view_stop, post_view_size
        return None
| 1018 | + |
| 1019 | + |
784 | 1020 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
785 | 1021 | class PropagateSlice(RemoveOrReplacePassInterface): |
786 | 1022 | """Propagate slice_copy before element-wise ops when the cost model |
|
0 commit comments