Skip to content

Commit 4cdb6cb

Browse files
laithsakkafacebook-github-bot
authored andcommitted
remove deprecated guard_size_oblivious from stride sorting logic in exir. (#19516)
Summary: guard_size_oblivious is deprecated use an explicit logic. Differential Revision: D104701778
1 parent 3d86cc7 commit 4cdb6cb

2 files changed

Lines changed: 64 additions & 18 deletions

File tree

exir/tensor.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,30 +69,30 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
6969
"""
7070
from torch.fx.experimental.symbolic_shapes import (
7171
guard_or_false,
72-
guard_size_oblivious,
72+
guard_or_true,
7373
)
7474

75-
for _, s in enumerate(stride):
76-
if guard_or_false(s == 0):
77-
raise ValueError("0 in strides is not supported for ExecuTorch.")
75+
for s in stride:
76+
torch._check(s != 0, lambda: "0 in strides is not supported for ExecuTorch.")
7877

7978
class K(NamedTuple):
8079
stride: int
8180

8281
def __lt__(self, other):
83-
return guard_size_oblivious(self.stride < other.stride)
84-
85-
def __gt__(self, other):
86-
return guard_size_oblivious(self.stride > other.stride)
87-
88-
def __le__(self, other):
89-
return guard_size_oblivious(self.stride <= other.stride)
90-
91-
def __ge__(self, other):
92-
return guard_size_oblivious(self.stride >= other.stride)
93-
94-
def __eq__(self, other):
95-
return guard_size_oblivious(self.stride == other.stride)
82+
# For backed/concrete strides this is practically a `<` operation.
83+
# For unbacked, we return True if `<` is statically known, then
84+
# try to answer symbolically with stride-ordering semantics:
85+
# u0 < u0 -> False
86+
# u0 < u1 (no info) -> DDE
87+
# u0 < 2 * u0 -> True (divisibility)
88+
# 1 < u0 -> True (1 divides anything)
89+
if guard_or_false(self.stride < other.stride):
90+
return True # statically known inequality
91+
if guard_or_false(other.stride % self.stride == 0) and guard_or_true(
92+
self.stride != other.stride
93+
):
94+
return True # symbolic inequality (e.g. u0 < 2048 * u0)
95+
return self.stride < other.stride
9696

9797
sorted_dims = [
9898
i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True)

exir/tests/test_tensor.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,55 @@ def test_dim_order_from_stride(self) -> None:
246246
# dim[2] is broadcasting dim
247247
# shape = (5, 1, 15, 10)
248248
strides = (10, 10, 0, 1)
249-
with self.assertRaises(ValueError):
249+
# torch._check raises RuntimeError on concrete 0.
250+
with self.assertRaises(RuntimeError):
250251
dim_order = dim_order_from_stride(strides)
251252

253+
def test_dim_order_from_stride_unbacked(self) -> None:
254+
"""
255+
dim_order_from_stride should produce a sane permutation even when the
256+
strides contain unbacked SymInts. The comparator falls back to
257+
divisibility-based reasoning so common cases like (1, u0) and
258+
(u0, 2 * u0) order correctly.
259+
"""
260+
from torch.fx.experimental.symbolic_shapes import (
261+
GuardOnDataDependentSymNode,
262+
ShapeEnv,
263+
)
264+
265+
shape_env = ShapeEnv()
266+
u0 = shape_env.create_unbacked_symint()
267+
u1 = shape_env.create_unbacked_symint()
268+
269+
# 1 < u0 should be True via divisibility (u0 % 1 == 0) + optimistic
270+
# `1 != u0`. Descending sort puts u0 outer, stride 1 inner.
271+
dim_order = dim_order_from_stride((1, u0))
272+
self.assertEqual((1, 0), dim_order)
273+
274+
# u0 < 2 * u0 should be True via divisibility ((2*u0) % u0 == 0) and
275+
# provable inequality (u0 != 0 after torch._check).
276+
dim_order = dim_order_from_stride((u0, 2 * u0))
277+
self.assertEqual((1, 0), dim_order)
278+
279+
# Mixed concrete + symbolic: (1, u0, 2 * u0). Descending stride order
280+
# is (2*u0, u0, 1) -> indices (2, 1, 0).
281+
dim_order = dim_order_from_stride((1, u0, 2 * u0))
282+
self.assertEqual((2, 1, 0), dim_order)
283+
284+
# u0 < u1 (independent unbackeds) is genuinely ambiguous, so the
285+
# comparator must fall back to the raw `<` and raise DDE.
286+
with self.assertRaises(GuardOnDataDependentSymNode):
287+
dim_order_from_stride((u0, u1))
288+
289+
# u0 < u0 is False both ways (symmetric); stable sort preserves order.
290+
dim_order = dim_order_from_stride((u0, u0))
291+
self.assertEqual((0, 1), dim_order)
292+
293+
# Unbacked stride of 0 (concrete 0 mixed with unbacked) -> RuntimeError
294+
# via torch._check.
295+
with self.assertRaises(RuntimeError):
296+
dim_order_from_stride((u0, 0, 1))
297+
252298
def test_strides_from_dim_order(self) -> None:
253299
sizes = []
254300
dim_order = []

0 commit comments

Comments
 (0)