@@ -246,9 +246,55 @@ def test_dim_order_from_stride(self) -> None:
246246 # dim[2] is broadcasting dim
247247 # shape = (5, 1, 15, 10)
248248 strides = (10 , 10 , 0 , 1 )
249- with self .assertRaises (ValueError ):
249+ # torch._check raises RuntimeError on concrete 0.
250+ with self .assertRaises (RuntimeError ):
250251 dim_order = dim_order_from_stride (strides )
251252
253+ def test_dim_order_from_stride_unbacked (self ) -> None :
254+ """
255+ dim_order_from_stride should produce a sane permutation even when the
256+ strides contain unbacked SymInts. The comparator falls back to
257+ divisibility-based reasoning so common cases like (1, u0) and
258+ (u0, 2 * u0) order correctly.
259+ """
260+ from torch .fx .experimental .symbolic_shapes import (
261+ GuardOnDataDependentSymNode ,
262+ ShapeEnv ,
263+ )
264+
265+ shape_env = ShapeEnv ()
266+ u0 = shape_env .create_unbacked_symint ()
267+ u1 = shape_env .create_unbacked_symint ()
268+
269+ # 1 < u0 should be True via divisibility (u0 % 1 == 0) + optimistic
270+ # `1 != u0`. Descending sort puts u0 outer, stride 1 inner.
271+ dim_order = dim_order_from_stride ((1 , u0 ))
272+ self .assertEqual ((1 , 0 ), dim_order )
273+
274+ # u0 < 2 * u0 should be True via divisibility ((2*u0) % u0 == 0) and
275+ # provable inequality (u0 != 0 after torch._check).
276+ dim_order = dim_order_from_stride ((u0 , 2 * u0 ))
277+ self .assertEqual ((1 , 0 ), dim_order )
278+
279+ # Mixed concrete + symbolic: (1, u0, 2 * u0). Descending stride order
280+ # is (2*u0, u0, 1) -> indices (2, 1, 0).
281+ dim_order = dim_order_from_stride ((1 , u0 , 2 * u0 ))
282+ self .assertEqual ((2 , 1 , 0 ), dim_order )
283+
284+ # u0 < u1 (independent unbackeds) is genuinely ambiguous, so the
285+ # comparator must fall back to the raw `<` and raise DDE.
286+ with self .assertRaises (GuardOnDataDependentSymNode ):
287+ dim_order_from_stride ((u0 , u1 ))
288+
289+ # u0 < u0 is False both ways (symmetric); stable sort preserves order.
290+ dim_order = dim_order_from_stride ((u0 , u0 ))
291+ self .assertEqual ((0 , 1 ), dim_order )
292+
293+ # Unbacked stride of 0 (concrete 0 mixed with unbacked) -> RuntimeError
294+ # via torch._check.
295+ with self .assertRaises (RuntimeError ):
296+ dim_order_from_stride ((u0 , 0 , 1 ))
297+
252298 def test_strides_from_dim_order (self ) -> None :
253299 sizes = []
254300 dim_order = []
0 commit comments