H100 CUDA Kernels for Gated Delta Net
- CUDA 12.9+ (for nvrtc), CUDA 12.3 for nvcc jit
- GCC 14+
- CMAKE 3.10+
- Python 3.10+
- Cutlass 4.2.0
- fmt
cmake -S . -B build only configures/generates build files. It does not compile.
Use either:
cmake -S . -B build
cmake --build build -j"$(nproc)"or the one-command wrapper:
./run_cmake.shTo build into a different folder (for example build_cmake):
./run_cmake.sh build_cmake- I put an example environment file in the root of the repository called
.example_env, with explanations of the variables used. You can copy it to.envand edit it to your needs.
The Python module name is gdn_cuda and maps directly to the C++ API in
gdn_cuda/include/gdn_cuda/api.hpp.
All APIs use the current CUDA stream from PyTorch.
| Function | Purpose | Inputs (shape) | DTypes | Returns |
|---|---|---|---|---|
init(library_root, cuda_home) |
Initialize JIT/compiler paths. | Host strings. | N/A | None |
chunked_forward(query, key, value, beta, gate, state=None, cu_seqlens=None, chunk_indices=None, cu_chunks=None) |
Chunked prefill forward pass. | padded: q/k (B,T,kh,sk), v (B,T,vh,sv), beta/gate (B,T,vh); varlen: q/k (N,kh,sk), v (N,vh,sv), beta/gate (N,vh), cu_seqlens (B+1), chunk_indices (2*chunks), cu_chunks (B+1) |
q/k/v/state: BF16, beta: BF16, gate: FP32 expected, metadata: INT32 |
(output, final_state) |
recurrent_forward(query, key, value, state, beta, gate, cu_seqlens=None, num_accepted_tokens=None, inference_mode=0, is_qk_norm=False) |
Recurrent decode/spec-verify forward pass. | padded: q/k/v (B,T,h,d), state (B,vh,sv,sk); varlen: q/k/v (N,h,d), cu_seqlens (B+1) |
recurrent path currently expects BF16 data tensors; metadata INT32 | (output, new_state) |
fused_gdn_gating(A_log, dt_bias, a, b, is_var_len=False) |
Compute (beta, gate) from raw GDN params. |
Kernel-defined compatible shapes. | A_log: FP32 or BF16 (current impl), others follow kernel path |
(beta, gate) |
| Function | Purpose | Inputs (shape) | DTypes | Mutation |
|---|---|---|---|---|
bf16_gdn_compute_u_w(k, v, u, w, beta, gate=None, compiled_dims="t", cu_seqlens=None, chunk_indices=None) |
Compute chunked U/W intermediates. | padded rank-4 or varlen rank-3 tensors. | k/v/u/w/beta: BF16, gate: FP32 MN-major, metadata INT32 |
Writes u, w |
bf16_chunked_seq_state_update(k, u, w, state, final_state, gate=None, is_initial_state=False, output_final_state=True, compiled_dims="t", cu_seqlens=None, chunk_size=64) |
Propagate per-chunk recurrent state. | padded rank-4/5 or varlen rank-3/4 tensors. | data BF16, gate FP32 MN-major |
Writes state, final_state, may update u |
bf16_gdn_chunked_compute_O(q, state, k, u, o, gate=None, compiled_dims="t", cu_seqlens=None, chunk_indices=None, cu_chunks=None, chunk_size=64) |
Compute output O from propagated state. |
padded rank-4/5 or varlen rank-3/4 + metadata | data BF16, gate FP32, metadata INT32 |
Writes o |
chunk_local_cumsum_bf16(input, output, batch_size, seq_len, num_heads, head_first=False, cu_seqlens=None, chunk_indices=None) |
Chunk-local inclusive cumsum helper. | rank depends on layout mode | input BF16, output float accumulation path | Writes output |
transpose_to_mn_major(input, alignment=16) |
Layout helper for MN-major kernel access. | rank >= 2 | FP32 or BF16 | Returns new tensor |
Note: For tensors with head dimensions * type_size that aren't divisible by 32, currently only the chunked_forward API will handle padding - the low-level component kernels will not pad the tensors. To-do is probably to add it to every kernel, it's just a few more lines of code
Several kernels require matrix-like tensors in MN-major layout. In this codebase:
betamust be MN-major forbf16_gdn_compute_u_w.gatemust be MN-major for chunkedcompute_u_w,seq_state_update, andcompute_O.- Use
transpose_to_mn_majorbefore invoking low-level wrappers.
Current behavior in this repository:
- Chunked helper kernels (
chunk_local_cumsumand related paths) expect gate values in FP32. - Recurrent launcher path currently passes gate as BF16 pointer.
fused_gdn_gatingmay emit BF16 gate whenA_logis BF16.
I mirrored DeepGEMM's repository structure, as I like their simple, clean JIT infra + header-only kernels, I think it makes the code readable and a lot easier to understand.
I would recommend using run_cmake.sh to build for only cpp testing, it's (slightly) faster, and easier to debug quickly. Note that currently it's still a little unpolished since I customized it for my system. If you want to work with python, use the develop.sh script to build the python bindings.
| Area | Target path | Responsibility |
|---|---|---|
| Python binding entrypoint | csrc/python_api.cpp |
Pybind module declarations and docstrings only |
| Public C++ API wrappers | csrc/apis/*.hpp + csrc/apis/*.cpp |
Tensor validation, shape adaptation, stream plumbing |
| JIT runtime core | csrc/jit/*.hpp |
Compiler, cache, runtime handles |
| JIT kernel launch wrappers | csrc/jit_kernels/impls/*.hpp |
Launch argument packing and runtime dispatch |
| JIT heuristics/config | csrc/jit_kernels/heuristics/*.hpp |
Launch configuration heuristics |
| CUDA kernels (headers only) | include/gdn_cuda/kernels/*.cuh |
Device kernel definitions |
| Common CUDA utilities | include/gdn_cuda/kernels/common/*.hpp/.cuh |
Shared device-side helpers |
- add support for a qk l2 norm that happens in gdn chunked forward, and add this into tests as well
- write guide on how to use jit and env variables
- ping pong scheduling on all kernels for heavier tma stores
- probably only needed for head_dim > 128, the current 64 x 128