Which component has the problem?
CuTe DSL
Bug Report
Version
nvidia-cutlass-dsl 4.5.2 (latest stable on PyPI; also present in 4.5.0/4.5.1)
- JAX integration (cutlass.jax.cutlass_call)
- Target: SM100 (Blackwell), Python 3.13, CUDA 12.x
Summary
When a JAX input tensor has any stride ≥ 2³¹ elements and the kernel is wrapped
with cutlass_call(..., use_static_tensors=False), compilation crashes while the
dynamic-stride layout is being built. The divisibility/alignment assumption for that
stride is passed into ConstrainedIntType.get(divisible_by, width), whose
divisible_by argument is bound as a C++ int32, so any value ≥ 2³¹ raises
TypeError: incompatible function arguments.
This is especially surprising when the offending stride belongs to a size-1
leading dimension (a degenerate axis that is never indexed): its stride equals the
size of the entire inner block, so a contiguous [1, N, H, D] tensor with
N·H·D ≥ 2³¹ crashes even though batch index 0 is the only one ever used.
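The arithmetic behind the size-1 case can be checked in plain Python. This is only a sketch: it assumes a dense row-major layout (which is what a contiguous JAX array implies), and the helper names `contiguous_strides` and `strides_exceeding_int32` are hypothetical, not part of cutlass or JAX.

```python
INT32_MAX = 2**31 - 1  # largest value a signed 32-bit integer can hold


def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a dense tensor."""
    strides = []
    running = 1
    # Walk from the innermost axis outward, accumulating the block size.
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


def strides_exceeding_int32(shape):
    """Return (axis, stride) pairs whose stride does not fit in a signed int32."""
    return [
        (axis, stride)
        for axis, stride in enumerate(contiguous_strides(shape))
        if stride > INT32_MAX
    ]


shape = (1, 839680, 20, 128)
print(contiguous_strides(shape))       # (2149580800, 2560, 128, 1)
print(strides_exceeding_int32(shape))  # [(0, 2149580800)]
```

Only axis 0, the degenerate size-1 axis, carries a stride that exceeds the int32 range; the strides of the axes that are actually indexed are all small.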
Related (runtime int32-stride limitation): #2312. This report concerns the
compile-time cutlass.jax divisibility path, which I don't think is covered
there.
Reproducer
import jax, jax.numpy as jnp
import cutlass, cutlass.cute as cute
from cutlass.jax import cutlass_call
import cuda.bindings.driver as cuda_driver
# Leading (size-1) axis stride = N*H*D = 839680*20*128 = 2_149_580_800 > 2**31
B, N, H, D = 1, 839680, 20, 128
@cute.jit
def launch(stream: cuda_driver.CUstream, mX: cute.Tensor, mO: cute.Tensor):
    # body irrelevant; crash happens before this runs, while building the
    # dynamic-stride layout for the inputs.
    ...
call = cutlass_call(
    launch,
    output_shape_dtype=[jax.ShapeDtypeStruct((B, N, H, D), jnp.bfloat16)],
    use_static_tensors=False,  # <-- True compiles fine; False crashes
)
x = jnp.zeros((B, N, H, D), jnp.bfloat16)
call(x) # raises during compilation