@@ -62,56 +62,45 @@ __device__ __forceinline__ uint8_t reduce_init_highest<uint8_t>() {
6262// but also expect a f32 output so that this can be used for normalization e.g.
6363// in softmax.
6464
65- // Optimized reduce sum: contiguous fast path with vectorized loads + warp shuffle,
66- // falls back to strided path for non-contiguous data.
65+ // Fast reduce sum kernel, this assumes that the dimensions to loop over are at
66+ // the end, each block is responsible for populating one value in the output
67+ // array. There are at most 1024 threads per block.
6768template <typename T>
6869__device__ void
6970fast_sum (const size_t src_numel, const size_t el_to_sum_per_block,
7071 const size_t num_dims, const size_t *info, const T *src, T *dst) {
7172 const size_t *dims = info;
7273 const size_t *strides = info + num_dims;
7374
74- __shared__ float shr[BLOCK_SIZE ];
75+ __shared__ T shr[BLOCK_SIZE ];
7576 size_t tid = threadIdx .x ;
7677 size_t dst_id = blockIdx .x ;
7778
79+ shr[tid] = 0 ;
80+ // Elements summed in this block range from dst_id * el_to_sum_per_block
81+ // to (dst_id + 1) * el_to_sum_per_block.
7882 size_t start_idx = dst_id * el_to_sum_per_block;
7983 size_t stop_idx = min (start_idx + el_to_sum_per_block, src_numel);
84+ size_t idx = start_idx + tid;
8085
81- float local_sum = 0 .0f ;
82-
83- if (is_contiguous (num_dims, dims, strides)) {
84- size_t idx = start_idx + tid;
85- while (idx < stop_idx) {
86- local_sum += static_cast <float >(src[idx]);
87- idx += blockDim .x ;
88- }
89- } else {
90- size_t idx = start_idx + tid;
91- while (idx < stop_idx) {
92- size_t strided_i = get_strided_index (idx, num_dims, dims, strides);
93- local_sum += static_cast <float >(src[strided_i]);
94- idx += blockDim .x ;
95- }
86+ while (idx < stop_idx) {
87+ // TODO: Fast version for the contiguous case.
88+ size_t strided_i = get_strided_index (idx, num_dims, dims, strides);
89+ shr[tid] += src[strided_i];
90+ idx += blockDim .x ;
9691 }
9792
98- // Warp-level reduction first
99- for (int offset = WARP_SIZE / 2 ; offset > 0 ; offset >>= 1 )
100- local_sum += __shfl_xor_sync (0xffffffff , local_sum, offset);
101-
102- int warp_id = tid / WARP_SIZE ;
103- int lane_id = tid % WARP_SIZE ;
104- if (lane_id == 0 ) shr[warp_id] = local_sum;
105- __syncthreads ();
106-
107- // Final reduction across warps
108- int num_warps = blockDim .x / WARP_SIZE ;
109- if (tid < WARP_SIZE ) {
110- local_sum = (tid < num_warps) ? shr[tid] : 0 .0f ;
111- for (int offset = WARP_SIZE / 2 ; offset > 0 ; offset >>= 1 )
112- local_sum += __shfl_xor_sync (0xffffffff , local_sum, offset);
113- if (tid == 0 ) dst[dst_id] = static_cast <T>(local_sum);
93+ // Parallel reduction, see the slides:
94+ // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
95+ // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
96+ for (int s = blockDim .x / 2 ; s > 0 ; s >>= 1 ) {
97+ __syncthreads ();
98+ if (tid < s)
99+ shr[tid] += shr[tid + s];
114100 }
101+
102+ if (tid == 0 )
103+ dst[dst_id] = shr[0 ];
115104}
116105
117106// Specialized vectorized fast_sum for bf16: 8 elements per float4 load
@@ -747,22 +736,20 @@ fast_sum_small_impl(const size_t src_numel, const size_t el_to_sum_per_block,
747736 size_t gid = blockIdx .x * blockDim .x + threadIdx .x ;
748737 if (gid >= dst_el) return ;
749738
750- float sum = 0 . 0f ;
739+ T sum = 0 ;
751740
752741 if (is_contiguous (num_dims, dims, strides)) {
753742 size_t start = gid * el_to_sum_per_block;
754743 for (size_t i = 0 ; i < el_to_sum_per_block; ++i) {
755- sum += static_cast < float >( src[start + i]) ;
744+ sum += src[start + i];
756745 }
757746 } else {
758- size_t base_linear = gid * el_to_sum_per_block;
759- size_t base_strided = get_strided_index (base_linear, num_dims, dims, strides);
760- size_t sum_stride = strides[num_dims - 1 ];
761747 for (size_t i = 0 ; i < el_to_sum_per_block; ++i) {
762- sum += static_cast <float >(src[base_strided + i * sum_stride]);
748+ size_t strided_i = get_strided_index (gid * el_to_sum_per_block + i, num_dims, dims, strides);
749+ sum += src[strided_i];
763750 }
764751 }
765- dst[gid] = static_cast <T>( sum) ;
752+ dst[gid] = sum;
766753}
767754
768755#if __CUDA_ARCH__ >= 800
@@ -784,15 +771,9 @@ extern "C" __global__ void fast_sum_small_bf16(
784771 sum += __bfloat162float (src[start + i]);
785772 }
786773 } else {
787- // Compute base address once via get_strided_index, then use the stride
788- // of the innermost (sum) dimension for subsequent elements.
789- // This avoids expensive integer division per element.
790- size_t base_linear = gid * el_to_sum_per_block;
791- size_t base_strided = get_strided_index (base_linear, num_dims, dims, strides);
792- size_t sum_stride = strides[num_dims - 1 ];
793-
794774 for (size_t i = 0 ; i < el_to_sum_per_block; ++i) {
795- sum += __bfloat162float (src[base_strided + i * sum_stride]);
775+ size_t strided_i = get_strided_index (gid * el_to_sum_per_block + i, num_dims, dims, strides);
776+ sum += __bfloat162float (src[strided_i]);
796777 }
797778 }
798779
@@ -814,6 +795,27 @@ extern "C" __global__ void fast_sum_small_f64(
814795 fast_sum_small_impl (src_numel, el_to_sum_per_block, num_dims, info, src, dst);
815796}
816797
798+ extern " C" __global__ void fast_sum_small_u32 (
799+ const size_t src_numel, const size_t el_to_sum_per_block,
800+ const size_t num_dims, const size_t *info, const uint32_t *src,
801+ uint32_t *dst) {
802+ fast_sum_small_impl (src_numel, el_to_sum_per_block, num_dims, info, src, dst);
803+ }
804+
805+ extern " C" __global__ void fast_sum_small_i64 (
806+ const size_t src_numel, const size_t el_to_sum_per_block,
807+ const size_t num_dims, const size_t *info, const int64_t *src,
808+ int64_t *dst) {
809+ fast_sum_small_impl (src_numel, el_to_sum_per_block, num_dims, info, src, dst);
810+ }
811+
812+ extern " C" __global__ void fast_sum_small_u8 (
813+ const size_t src_numel, const size_t el_to_sum_per_block,
814+ const size_t num_dims, const size_t *info, const uint8_t *src,
815+ uint8_t *dst) {
816+ fast_sum_small_impl (src_numel, el_to_sum_per_block, num_dims, info, src, dst);
817+ }
818+
817819#if __CUDA_ARCH__ >= 530
818820extern " C" __global__ void fast_sum_small_f16 (
819821 const size_t src_numel, const size_t el_to_sum_per_block,
0 commit comments