Skip to content

Commit ce68a51

Browse files
committed
Preserve original fast_sum for u8, u32, i64 and f64 dtypes
1 parent 6d95f80 commit ce68a51

3 files changed

Lines changed: 70 additions & 49 deletions

File tree

candle-core/tests/tensor_tests.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,24 @@ fn sum(device: &Device) -> Result<()> {
487487
]]
488488
);
489489
}
490+
491+
let data = &[[1u8, 2, 3], [4, 5, 6]];
492+
let tensor = Tensor::new(data, device)?;
493+
assert_eq!(tensor.sum_keepdim(1)?.to_vec2::<u8>()?, &[[6], [15]]);
494+
let data = &[[1i64, 2, 3], [4, 5, 6]];
495+
let tensor = Tensor::new(data, device)?;
496+
assert_eq!(tensor.sum_keepdim(1)?.to_vec2::<i64>()?, &[[6], [15]]);
497+
498+
let mut data = vec![16_777_217u32];
499+
data.extend([1u32; 32]);
500+
let tensor = Tensor::new(data.as_slice(), device)?;
501+
assert_eq!(tensor.sum_keepdim(0)?.to_vec1::<u32>()?, &[16_777_249]);
502+
if !device.is_metal() {
503+
let mut data = vec![16_777_217f64];
504+
data.extend([1f64; 32]);
505+
let tensor = Tensor::new(data.as_slice(), device)?;
506+
assert_eq!(tensor.sum_keepdim(0)?.to_vec1::<f64>()?, &[16_777_249.]);
507+
}
490508
Ok(())
491509
}
492510

candle-kernels/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::path::PathBuf;
44

55
fn main() -> Result<()> {
66
println!("cargo::rerun-if-changed=build.rs");
7+
println!("cargo::rerun-if-changed=src");
78
println!("cargo::rerun-if-changed=src/compatibility.cuh");
89
println!("cargo::rerun-if-changed=src/cuda_utils.cuh");
910
println!("cargo::rerun-if-changed=src/binary_op_macros.cuh");

candle-kernels/src/reduce.cu

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -62,56 +62,45 @@ __device__ __forceinline__ uint8_t reduce_init_highest<uint8_t>() {
6262
// but also expect a f32 output so that this can be used for normalization e.g.
6363
// in softmax.
6464

65-
// Optimized reduce sum: contiguous fast path with vectorized loads + warp shuffle,
66-
// falls back to strided path for non-contiguous data.
65+
// Fast reduce sum kernel, this assumes that the dimensions to loop over are at
66+
// the end, each block is responsible for populating one value in the output
67+
// array. There are at most 1024 threads per block.
6768
template <typename T>
6869
__device__ void
6970
fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
7071
const size_t num_dims, const size_t *info, const T *src, T *dst) {
7172
const size_t *dims = info;
7273
const size_t *strides = info + num_dims;
7374

74-
__shared__ float shr[BLOCK_SIZE];
75+
__shared__ T shr[BLOCK_SIZE];
7576
size_t tid = threadIdx.x;
7677
size_t dst_id = blockIdx.x;
7778

79+
shr[tid] = 0;
80+
// Elements summed in this block range from dst_id * el_to_sum_per_block
81+
// to (dst_id + 1) * el_to_sum_per_block.
7882
size_t start_idx = dst_id * el_to_sum_per_block;
7983
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
84+
size_t idx = start_idx + tid;
8085

81-
float local_sum = 0.0f;
82-
83-
if (is_contiguous(num_dims, dims, strides)) {
84-
size_t idx = start_idx + tid;
85-
while (idx < stop_idx) {
86-
local_sum += static_cast<float>(src[idx]);
87-
idx += blockDim.x;
88-
}
89-
} else {
90-
size_t idx = start_idx + tid;
91-
while (idx < stop_idx) {
92-
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
93-
local_sum += static_cast<float>(src[strided_i]);
94-
idx += blockDim.x;
95-
}
86+
while (idx < stop_idx) {
87+
// TODO: Fast version for the contiguous case.
88+
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
89+
shr[tid] += src[strided_i];
90+
idx += blockDim.x;
9691
}
9792

98-
// Warp-level reduction first
99-
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
100-
local_sum += __shfl_xor_sync(0xffffffff, local_sum, offset);
101-
102-
int warp_id = tid / WARP_SIZE;
103-
int lane_id = tid % WARP_SIZE;
104-
if (lane_id == 0) shr[warp_id] = local_sum;
105-
__syncthreads();
106-
107-
// Final reduction across warps
108-
int num_warps = blockDim.x / WARP_SIZE;
109-
if (tid < WARP_SIZE) {
110-
local_sum = (tid < num_warps) ? shr[tid] : 0.0f;
111-
for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1)
112-
local_sum += __shfl_xor_sync(0xffffffff, local_sum, offset);
113-
if (tid == 0) dst[dst_id] = static_cast<T>(local_sum);
93+
// Parallel reduction, see the slides:
94+
// https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf
95+
// https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce
96+
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
97+
__syncthreads();
98+
if (tid < s)
99+
shr[tid] += shr[tid + s];
114100
}
101+
102+
if (tid == 0)
103+
dst[dst_id] = shr[0];
115104
}
116105

117106
// Specialized vectorized fast_sum for bf16: 8 elements per float4 load
@@ -747,22 +736,20 @@ fast_sum_small_impl(const size_t src_numel, const size_t el_to_sum_per_block,
747736
size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
748737
if (gid >= dst_el) return;
749738

750-
float sum = 0.0f;
739+
T sum = 0;
751740

752741
if (is_contiguous(num_dims, dims, strides)) {
753742
size_t start = gid * el_to_sum_per_block;
754743
for (size_t i = 0; i < el_to_sum_per_block; ++i) {
755-
sum += static_cast<float>(src[start + i]);
744+
sum += src[start + i];
756745
}
757746
} else {
758-
size_t base_linear = gid * el_to_sum_per_block;
759-
size_t base_strided = get_strided_index(base_linear, num_dims, dims, strides);
760-
size_t sum_stride = strides[num_dims - 1];
761747
for (size_t i = 0; i < el_to_sum_per_block; ++i) {
762-
sum += static_cast<float>(src[base_strided + i * sum_stride]);
748+
size_t strided_i = get_strided_index(gid * el_to_sum_per_block + i, num_dims, dims, strides);
749+
sum += src[strided_i];
763750
}
764751
}
765-
dst[gid] = static_cast<T>(sum);
752+
dst[gid] = sum;
766753
}
767754

768755
#if __CUDA_ARCH__ >= 800
@@ -784,15 +771,9 @@ extern "C" __global__ void fast_sum_small_bf16(
784771
sum += __bfloat162float(src[start + i]);
785772
}
786773
} else {
787-
// Compute base address once via get_strided_index, then use the stride
788-
// of the innermost (sum) dimension for subsequent elements.
789-
// This avoids expensive integer division per element.
790-
size_t base_linear = gid * el_to_sum_per_block;
791-
size_t base_strided = get_strided_index(base_linear, num_dims, dims, strides);
792-
size_t sum_stride = strides[num_dims - 1];
793-
794774
for (size_t i = 0; i < el_to_sum_per_block; ++i) {
795-
sum += __bfloat162float(src[base_strided + i * sum_stride]);
775+
size_t strided_i = get_strided_index(gid * el_to_sum_per_block + i, num_dims, dims, strides);
776+
sum += __bfloat162float(src[strided_i]);
796777
}
797778
}
798779

@@ -814,6 +795,27 @@ extern "C" __global__ void fast_sum_small_f64(
814795
fast_sum_small_impl(src_numel, el_to_sum_per_block, num_dims, info, src, dst);
815796
}
816797

798+
extern "C" __global__ void fast_sum_small_u32(
799+
const size_t src_numel, const size_t el_to_sum_per_block,
800+
const size_t num_dims, const size_t *info, const uint32_t *src,
801+
uint32_t *dst) {
802+
fast_sum_small_impl(src_numel, el_to_sum_per_block, num_dims, info, src, dst);
803+
}
804+
805+
extern "C" __global__ void fast_sum_small_i64(
806+
const size_t src_numel, const size_t el_to_sum_per_block,
807+
const size_t num_dims, const size_t *info, const int64_t *src,
808+
int64_t *dst) {
809+
fast_sum_small_impl(src_numel, el_to_sum_per_block, num_dims, info, src, dst);
810+
}
811+
812+
extern "C" __global__ void fast_sum_small_u8(
813+
const size_t src_numel, const size_t el_to_sum_per_block,
814+
const size_t num_dims, const size_t *info, const uint8_t *src,
815+
uint8_t *dst) {
816+
fast_sum_small_impl(src_numel, el_to_sum_per_block, num_dims, info, src, dst);
817+
}
818+
817819
#if __CUDA_ARCH__ >= 530
818820
extern "C" __global__ void fast_sum_small_f16(
819821
const size_t src_numel, const size_t el_to_sum_per_block,

0 commit comments

Comments
 (0)