Optimize CUDA kernels #3600

Open
guoqingbao wants to merge 3 commits into huggingface:main from guoqingbao:cuda_kernels

@guoqingbao

Contributor

Summary

This PR adds another round of CUDA kernel optimizations for binary ops, cast ops, reductions, and selected unary ops.

The main focus is the hot contiguous BF16/F32 paths: specialized vectorized kernels, less indexing overhead, and more efficient CUDA primitives where appropriate.

Combined, these optimizations deliver up to a 15% end-to-end speedup across various models.

Changes

Binary Ops

Updated binary.cu and binary_op_macros.cuh.

  • Optimized BF16 add/mul:
    • badd_bf16
    • bmul_bf16
  • Replaced generic implementations with vectorized kernels for contiguous inputs.
  • Contiguous BF16 kernels now use float4 loads, processing 8 BF16 elements per load.
  • BF16 add/mul use __hadd2 and __hmul2 intrinsics.
  • Added semi-contiguous paths for inputs where only the LHS or only the RHS is strided and the other operand is contiguous.
  • Optimized F32 add/div/mul:
    • badd_f32
    • bdiv_f32
    • bmul_f32
  • F32 contiguous kernels now use float4 loads, processing 4 F32 elements per load.
  • Added BINARY_OP_BF16_VEC in binary_op_macros.cuh for reusable vectorized BF16 binary ops with float-promoted per-element computation.
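
To make the "4 F32 elements per load" pattern concrete, here is a minimal host-side C++ model of a vectorized contiguous add. It is a sketch, not the kernel in this PR: `Float4` stands in for CUDA's `float4`, `add_f32_vec4` is a hypothetical name, and the scalar tail for lengths that are not a multiple of 4 is an assumption about how leftovers could be handled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Host stand-in for CUDA's float4: four floats moved as one 16-byte unit.
struct Float4 { float x, y, z, w; };
static_assert(sizeof(Float4) == 16, "one vector access covers 4 F32 values");

// Hypothetical helper: out[i] = a[i] + b[i] for contiguous buffers.
void add_f32_vec4(const float* a, const float* b, float* out, std::size_t n) {
    assert(n == 0 || (a && b && out));
    const std::size_t n_vec = n / 4;
    // Vector body: one 16-byte load per operand, one 16-byte store.
    for (std::size_t i = 0; i < n_vec; ++i) {
        Float4 va, vb, vo;
        std::memcpy(&va, a + 4 * i, sizeof(Float4));
        std::memcpy(&vb, b + 4 * i, sizeof(Float4));
        vo.x = va.x + vb.x;
        vo.y = va.y + vb.y;
        vo.z = va.z + vb.z;
        vo.w = va.w + vb.w;
        std::memcpy(out + 4 * i, &vo, sizeof(Float4));
    }
    // Scalar tail (assumed): elements left over when n is not a multiple of 4.
    for (std::size_t i = n_vec * 4; i < n; ++i) {
        out[i] = a[i] + b[i];
    }
}
```

On the device, each loop iteration of the vector body would correspond to one thread's work item, and the contiguous layout is what allows the stride lookup to be skipped entirely.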

Cast Ops

Updated cast.cu.

  • Optimized cast_bf16_f32 with vectorized contiguous loads.
  • BF16 to F32 conversion now uses float4 loads, processing 8 BF16 values per load.
  • The per-element conversion is unrolled and uses __bfloat162float.
  • Optimized cast_f32_bf16 with vectorized contiguous loads.
  • F32 to BF16 conversion now uses float4 loads, processing 4 F32 values per load.
  • Conversion uses __float2bfloat16_rn.
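
For readers unfamiliar with the BF16 layout, the following host-side C++ sketch models what the two conversions do at the bit level. It is illustrative only: the function names are hypothetical, BF16 is represented as a raw `uint16_t`, and the NaN handling is an assumption that may differ from the exact bit pattern the CUDA intrinsics return.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// 8 BF16 values occupy 16 bytes, the same width as one float4 load.
static_assert(8 * sizeof(std::uint16_t) == 16, "8 BF16 per 16-byte load");

// BF16 -> F32: BF16 is the upper 16 bits of an IEEE-754 binary32 value.
float bf16_to_f32(std::uint16_t h) {
    const std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// F32 -> BF16 with round-to-nearest-even (models the "_rn" rounding mode).
std::uint16_t f32_to_bf16_rn(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: keep it a NaN by setting a mantissa bit (assumed handling).
        return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
    }
    const std::uint32_t lsb = (bits >> 16) & 1u;  // lowest bit that survives
    bits += 0x7FFFu + lsb;                        // ties go to the even result
    return static_cast<std::uint16_t>(bits >> 16);
}

// Unrolled widening of one 16-byte chunk (8 BF16 values) to 8 floats.
void cast_bf16x8_to_f32(const std::uint16_t* src, float* dst) {
    assert(src && dst);
    dst[0] = bf16_to_f32(src[0]); dst[1] = bf16_to_f32(src[1]);
    dst[2] = bf16_to_f32(src[2]); dst[3] = bf16_to_f32(src[3]);
    dst[4] = bf16_to_f32(src[4]); dst[5] = bf16_to_f32(src[5]);
    dst[6] = bf16_to_f32(src[6]); dst[7] = bf16_to_f32(src[7]);
}
```

Note the asymmetry this implies: a 16-byte load yields 8 elements on the BF16 side but only 4 on the F32 side, so the two cast directions do different amounts of work per load.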

Reduce Ops

Updated reduce.cu and mod.rs.

  • Reworked the fast_sum template.
  • Replaced shared-memory tree reduction with a warp-shuffle two-phase reduction using __shfl_xor_sync.
  • Added a contiguous fast path that skips get_strided_index.
  • Accumulation now uses float for better precision with half-precision input types.
  • Added fast_sum_bf16_vec, a specialized BF16 vectorized reduce kernel using:
    • float4 loads
    • 8 BF16 elements per load
    • warp-shuffle reduction
  • Added an explicit optimized fast_sum_f32 path using:
    • float4 loads
    • 4 F32 elements per load
    • warp-shuffle reduction
  • Replaced the previous generic FAST_OP macro path for F32 sum.
  • Added fast_sum_small kernels for small reductions where el_to_sum <= 32.
  • Added specialized small-reduction variants for:
    • BF16
    • F16
    • F32
    • F64
  • Updated Rust dispatch in mod.rs with use_small_reduce logic.
  • Sum reductions with el_to_sum <= 32 now dispatch to fast_sum_small using 256-thread blocks instead of launching one block per output element.
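
The two-phase warp-shuffle reduction can be modeled on the host as follows. This is a sketch under assumptions, not the code in reduce.cu: the function names are hypothetical, each array slot stands for one lane's register, and the way per-warp partials are handed to the second phase is simplified to a plain array.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kWarpSize = 32;

// Models one warp running:
//   for (offset = 16; offset > 0; offset >>= 1)
//       v += __shfl_xor_sync(mask, v, offset);
// After the loop, every lane holds the sum of all 32 lanes.
void warp_butterfly_sum(std::array<float, kWarpSize>& lanes) {
    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
        // Snapshot: all lanes read their partner's value before any lane writes.
        const std::array<float, kWarpSize> snapshot = lanes;
        for (int lane = 0; lane < kWarpSize; ++lane) {
            lanes[lane] = snapshot[lane] + snapshot[lane ^ offset];
        }
    }
}

// Phase 1: reduce inside each warp. Phase 2: reduce the per-warp partials
// with one more butterfly. per_thread holds one float accumulator per thread.
float block_sum_two_phase(const std::vector<float>& per_thread) {
    assert(per_thread.size() % kWarpSize == 0);
    const std::size_t num_warps = per_thread.size() / kWarpSize;
    assert(num_warps <= static_cast<std::size_t>(kWarpSize));

    std::array<float, kWarpSize> partials{};  // unused slots stay 0
    for (std::size_t w = 0; w < num_warps; ++w) {
        std::array<float, kWarpSize> lanes{};
        for (int lane = 0; lane < kWarpSize; ++lane) {
            lanes[lane] = per_thread[w * kWarpSize + lane];
        }
        warp_butterfly_sum(lanes);
        partials[w] = lanes[0];
    }
    warp_butterfly_sum(partials);
    return partials[0];
}
```

The butterfly needs only log2(32) = 5 exchange steps per warp and no block-wide barrier inside a warp, which is where the saving over a shared-memory tree reduction comes from.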

Unary Ops

Updated unary.cu.

  • Optimized ucopy_bf16 with vectorized contiguous copies using float4 loads.
  • Optimized usilu_bf16 with vectorized contiguous processing:
    • 8 BF16 elements per thread iteration
    • float-promoted computation
    • computes x / (1 + exp(-x))
  • Optimized usigmoid_bf16 with vectorized contiguous processing:
    • 8 BF16 elements per thread iteration
    • float-promoted computation
    • computes 1 / (1 + exp(-x))
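
The float-promoted math for the two activations is small enough to show directly. The host-side C++ below is a sketch with hypothetical names; it operates on values already widened to `float` and uses chunks of 8 to mirror one thread iteration, leaving out the BF16 load and store steps.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// sigmoid(x) = 1 / (1 + exp(-x)), computed in float.
inline float sigmoid_f32(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// silu(x) = x / (1 + exp(-x)), which equals x * sigmoid(x).
inline float silu_f32(float x) { return x / (1.0f + std::exp(-x)); }

// Applies SiLU in place, 8 values per iteration. The length is assumed to be
// a whole number of 8-element chunks; tail handling is left out of this sketch.
void silu_f32_chunks(float* data, std::size_t n) {
    assert(n % 8 == 0);
    for (std::size_t base = 0; base < n; base += 8) {
        for (std::size_t k = 0; k < 8; ++k) {
            data[base + k] = silu_f32(data[base + k]);
        }
    }
}
```

Computing in `float` rather than directly in BF16 keeps the intermediate `exp` and division at full single precision, with rounding applied only once when the result is stored.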

Expected Impact

These changes should improve throughput for common CUDA workloads, especially contiguous BF16 and F32 operations.

The largest expected gains are in:

  • BF16 add/mul
  • F32 add/div/mul
  • BF16/F32 cast operations
  • Contiguous sum reductions
  • Small sum reductions with el_to_sum <= 32
  • BF16 copy, SiLU, and sigmoid

The optimizations reduce scalar memory access, avoid unnecessary stride indexing on contiguous paths, reduce shared-memory synchronization overhead in reductions, and improve utilization through vectorized loads.
