
[BUG] #3327

@harshvardhanagg

Description

Which component has the problem?

CuTe DSL

Bug Report

Version

  • nvidia-cutlass-dsl 4.5.2 (latest stable on PyPI; also reproduces on 4.5.0 and 4.5.1)
  • JAX integration (cutlass.jax.cutlass_call)
  • Target: SM100 (Blackwell), Python 3.13, CUDA 12.x

Summary

When a JAX input tensor has any stride ≥ 2³¹ elements and the kernel is wrapped
with cutlass_call(..., use_static_tensors=False), compilation crashes while the
dynamic-stride layout is being built. The divisibility (alignment) assumption for
that stride is passed to ConstrainedIntType.get(divisible_by, width), whose
divisible_by argument is bound as a C++ int32. Because an int32 tops out at
2³¹ − 1, any value ≥ 2³¹ fails argument conversion and raises
TypeError: incompatible function arguments.
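To make the boundary concrete: a signed 32-bit integer holds at most 2³¹ − 1 = 2,147,483,647, and the reproducer's leading stride lands just past it. The sketch below uses Python's standard struct module only as a stand-in for the C++ int32 argument conversion; `fits_int32` is a hypothetical helper, not part of cutlass.

```python
import struct

# Largest value representable by a signed 32-bit integer (C++ int32_t).
INT32_MAX = 2**31 - 1


def fits_int32(value: int) -> bool:
    """Hypothetical helper: True if `value` can be packed as a signed int32.

    struct.pack("<i", ...) is only an analogue of the int32 conversion done
    for ConstrainedIntType.get(divisible_by, ...); it is not the cutlass binding.
    """
    try:
        struct.pack("<i", value)  # "<i": little-endian signed 32-bit int
        return True
    except struct.error:  # raised when value is outside [-2**31, 2**31 - 1]
        return False


# Stride of the size-1 leading axis in the reproducer: N * H * D
stride = 839680 * 20 * 128

print(stride)              # 2149580800
print(stride - INT32_MAX)  # 2097153
print(fits_int32(stride))  # False
```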

This is especially surprising when the offending stride belongs to a size-1
leading dimension. Its only valid index is 0, so the stride never contributes to
an address, yet it equals the size of the entire inner block: a contiguous
[1, N, H, D] tensor with N·H·D ≥ 2³¹ crashes even though batch index 0 is the
only one ever used.
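The size-1 case can be checked by hand. In a dense row-major layout the stride of axis i is the product of all inner extents, so the leading stride is N·H·D regardless of the leading extent, and with batch index fixed at 0 that stride is always multiplied by 0. A minimal sketch in plain Python; `row_major_strides` and `linear_offset` are illustrative helpers, not cutlass or JAX APIs:

```python
from math import prod


def row_major_strides(shape):
    """Element strides of a dense row-major (C-order) tensor."""
    # Stride of axis i = product of the extents of all axes to its right.
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))


def linear_offset(index, strides):
    """Element offset of a multi-index: sum of index[k] * stride[k]."""
    return sum(k * s for k, s in zip(index, strides))


B, N, H, D = 1, 839680, 20, 128
strides = row_major_strides((B, N, H, D))
print(strides)  # (2149580800, 2560, 128, 1)

# Replacing the leading stride with 0 changes no offset while batch index is 0.
zeroed = (0,) + strides[1:]
idx = (0, 839679, 19, 127)  # last element of the tensor
print(linear_offset(idx, strides) == linear_offset(idx, zeroed))  # True
```

Only the leading stride exceeds 2³¹ − 1 here; the inner strides (2560, 128, 1) are small.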

Related (runtime int32-stride limitation): #2312. This report concerns the
compile-time cutlass.jax divisibility path, which I don't believe that issue
covers.

Reproducer

import jax, jax.numpy as jnp
import cutlass, cutlass.cute as cute
from cutlass.jax import cutlass_call
import cuda.bindings.driver as cuda_driver

# Leading (size-1) axis stride = N*H*D = 839680*20*128 = 2_149_580_800 > 2**31
B, N, H, D = 1, 839680, 20, 128

@cute.jit
def launch(stream: cuda_driver.CUstream, mX: cute.Tensor, mO: cute.Tensor):
    # body irrelevant; crash happens before this runs, while building the
    # dynamic-stride layout for the inputs.
    ...

call = cutlass_call(
    launch,
    output_shape_dtype=[jax.ShapeDtypeStruct((B, N, H, D), jnp.bfloat16)],
    use_static_tensors=False,   # <-- True compiles fine; False crashes
)

x = jnp.zeros((B, N, H, D), jnp.bfloat16)
call(x)   # raises during compilation
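The condition can also be detected before calling cutlass_call. The sketch below is a hypothetical pre-flight helper (not part of cutlass.jax); it assumes the input is dense and row-major, since it works from the shape alone and does not query the actual device layout. It only reports axes whose stride would not fit in an int32; it does not work around the crash.

```python
from math import prod

INT32_MAX = 2**31 - 1


def axes_with_oversized_stride(shape):
    """Hypothetical check: (axis, extent, stride) for strides > INT32_MAX.

    Assumes a dense row-major layout, so each stride is the product of the
    extents of all axes to the right of it.
    """
    flagged = []
    for axis, extent in enumerate(shape):
        stride = prod(shape[axis + 1:])  # product of inner extents
        if stride > INT32_MAX:
            flagged.append((axis, extent, stride))
    return flagged


# Shape from the reproducer: only the size-1 leading axis is flagged.
print(axes_with_oversized_stride((1, 839680, 20, 128)))
# [(0, 1, 2149580800)]

# Without the size-1 axis the strides are (2560, 128, 1): nothing is flagged.
print(axes_with_oversized_stride((839680, 20, 128)))
# []
```

For a JAX array, pass its shape, e.g. `axes_with_oversized_stride(x.shape)`. Per the report, use_static_tensors=True is the configuration that currently compiles.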
