Skip to content

Commit 5228b51

Browse files
Merge pull request #64 from quickwit-oss/faster_neon
Improve NEON instructions for BitPacker4x
2 parents 8b651a4 + 855806b commit 5228b51

3 files changed

Lines changed: 168 additions & 32 deletions

File tree

src/bitpacker4x.rs

Lines changed: 145 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -83,48 +83,165 @@ mod neon {
8383

8484
use super::BLOCK_LEN;
8585
use crate::Available;
86+
use std::arch::aarch64::{
87+
uint32x4_t, vaddq_u32, vandq_u32, vdupq_n_u32, vextq_u32, vgetq_lane_u32, vld1q_u32,
88+
vorrq_u32, vshlq_n_u32, vshrq_n_u32, vst1q_u32, vsubq_u32,
89+
};
90+
91+
pub(crate) type DataType = uint32x4_t;
92+
93+
#[inline]
94+
/// Creates a vector with all elements set to `el`.
95+
unsafe fn set1(el: i32) -> DataType {
96+
vdupq_n_u32(el as u32)
97+
}
98+
99+
#[inline]
100+
unsafe fn right_shift_32<const N: i32>(el: DataType) -> DataType {
101+
const {
102+
assert!(N >= 0);
103+
assert!(N <= 32);
104+
}
105+
106+
// We unroll here because vshrq_n_u32 only accepts constants from 1 to 32.
107+
match N {
108+
0 => el,
109+
1 => vshrq_n_u32::<1>(el),
110+
2 => vshrq_n_u32::<2>(el),
111+
3 => vshrq_n_u32::<3>(el),
112+
4 => vshrq_n_u32::<4>(el),
113+
5 => vshrq_n_u32::<5>(el),
114+
6 => vshrq_n_u32::<6>(el),
115+
7 => vshrq_n_u32::<7>(el),
116+
8 => vshrq_n_u32::<8>(el),
117+
9 => vshrq_n_u32::<9>(el),
118+
10 => vshrq_n_u32::<10>(el),
119+
11 => vshrq_n_u32::<11>(el),
120+
12 => vshrq_n_u32::<12>(el),
121+
13 => vshrq_n_u32::<13>(el),
122+
14 => vshrq_n_u32::<14>(el),
123+
15 => vshrq_n_u32::<15>(el),
124+
16 => vshrq_n_u32::<16>(el),
125+
17 => vshrq_n_u32::<17>(el),
126+
18 => vshrq_n_u32::<18>(el),
127+
19 => vshrq_n_u32::<19>(el),
128+
20 => vshrq_n_u32::<20>(el),
129+
21 => vshrq_n_u32::<21>(el),
130+
22 => vshrq_n_u32::<22>(el),
131+
23 => vshrq_n_u32::<23>(el),
132+
24 => vshrq_n_u32::<24>(el),
133+
25 => vshrq_n_u32::<25>(el),
134+
26 => vshrq_n_u32::<26>(el),
135+
27 => vshrq_n_u32::<27>(el),
136+
28 => vshrq_n_u32::<28>(el),
137+
29 => vshrq_n_u32::<29>(el),
138+
30 => vshrq_n_u32::<30>(el),
139+
31 => vshrq_n_u32::<31>(el),
140+
32 => vdupq_n_u32(0),
141+
_ => core::hint::unreachable_unchecked(),
142+
}
143+
}
144+
145+
#[inline]
146+
unsafe fn left_shift_32<const N: i32>(el: DataType) -> DataType {
147+
const {
148+
assert!(N >= 0);
149+
assert!(N <= 32);
150+
}
151+
152+
// We unroll here because vshlq_n_u32 only accepts constants from 0 to 31.
153+
match N {
154+
0 => el,
155+
1 => vshlq_n_u32::<1>(el),
156+
2 => vshlq_n_u32::<2>(el),
157+
3 => vshlq_n_u32::<3>(el),
158+
4 => vshlq_n_u32::<4>(el),
159+
5 => vshlq_n_u32::<5>(el),
160+
6 => vshlq_n_u32::<6>(el),
161+
7 => vshlq_n_u32::<7>(el),
162+
8 => vshlq_n_u32::<8>(el),
163+
9 => vshlq_n_u32::<9>(el),
164+
10 => vshlq_n_u32::<10>(el),
165+
11 => vshlq_n_u32::<11>(el),
166+
12 => vshlq_n_u32::<12>(el),
167+
13 => vshlq_n_u32::<13>(el),
168+
14 => vshlq_n_u32::<14>(el),
169+
15 => vshlq_n_u32::<15>(el),
170+
16 => vshlq_n_u32::<16>(el),
171+
17 => vshlq_n_u32::<17>(el),
172+
18 => vshlq_n_u32::<18>(el),
173+
19 => vshlq_n_u32::<19>(el),
174+
20 => vshlq_n_u32::<20>(el),
175+
21 => vshlq_n_u32::<21>(el),
176+
22 => vshlq_n_u32::<22>(el),
177+
23 => vshlq_n_u32::<23>(el),
178+
24 => vshlq_n_u32::<24>(el),
179+
25 => vshlq_n_u32::<25>(el),
180+
26 => vshlq_n_u32::<26>(el),
181+
27 => vshlq_n_u32::<27>(el),
182+
28 => vshlq_n_u32::<28>(el),
183+
29 => vshlq_n_u32::<29>(el),
184+
30 => vshlq_n_u32::<30>(el),
185+
31 => vshlq_n_u32::<31>(el),
186+
32 => vdupq_n_u32(0),
187+
_ => core::hint::unreachable_unchecked(),
188+
}
189+
}
190+
191+
use vorrq_u32 as op_or;
192+
193+
#[inline]
194+
unsafe fn op_and(left: DataType, right: DataType) -> DataType {
195+
vandq_u32(left, right)
196+
}
197+
198+
#[inline]
199+
unsafe fn load_unaligned(addr: *const DataType) -> DataType {
200+
vld1q_u32(addr.cast::<u32>())
201+
}
202+
203+
#[inline]
204+
unsafe fn store_unaligned(addr: *mut DataType, data: DataType) {
205+
vst1q_u32(addr.cast::<u32>(), data);
206+
}
207+
208+
#[inline]
209+
/// Collapses the vector by performing a bitwise OR across all lanes
210+
unsafe fn or_collapse_to_u32(acc: DataType) -> u32 {
211+
vgetq_lane_u32(acc, 0)
212+
| vgetq_lane_u32(acc, 1)
213+
| vgetq_lane_u32(acc, 2)
214+
| vgetq_lane_u32(acc, 3)
215+
}
86216

87-
use super::scalar::add;
88-
use super::scalar::left_shift_32;
89-
use super::scalar::load_unaligned;
90-
use super::scalar::op_and;
91-
use super::scalar::op_or;
92-
use super::scalar::or_collapse_to_u32;
93-
use super::scalar::right_shift_32;
94-
use super::scalar::set1;
95-
use super::scalar::store_unaligned;
96-
use super::scalar::sub;
97-
use super::scalar::DataType;
98-
use std::arch::aarch64::{vaddq_u32, vdupq_n_u32, vextq_u32, vld1q_u32, vst1q_u32, vsubq_u32};
99-
100-
#[target_feature(enable = "neon")]
101217
unsafe fn compute_delta(curr: DataType, prev: DataType) -> DataType {
102-
let c = vld1q_u32(curr.as_ptr());
103-
let p = vld1q_u32(prev.as_ptr());
104-
let mut r = set1(0);
105-
vst1q_u32(r.as_mut_ptr(), vsubq_u32(c, vextq_u32(p, c, 3)));
106-
r
218+
// Build a vector with [prev[3], curr[0], curr[1], curr[2]]
219+
let prev_shifted = vextq_u32(prev, curr, 3);
220+
vsubq_u32(curr, prev_shifted)
107221
}
108222

109-
#[target_feature(enable = "neon")]
110223
#[allow(non_snake_case)]
111224
#[inline]
112225
unsafe fn integrate_delta(prev: DataType, delta: DataType) -> DataType {
113-
let base = vdupq_n_u32(prev[3]);
226+
let base = vdupq_n_u32(vgetq_lane_u32(prev, 3));
114227
let zero = vdupq_n_u32(0);
115-
let a__b__c__d_ = vld1q_u32(delta.as_ptr());
228+
let a__b__c__d_ = delta;
116229
let ______a__b_ = vextq_u32(zero, a__b__c__d_, 2);
117230
let a__b__ca_db = vaddq_u32(______a__b_, a__b__c__d_);
118231
let ___a__b__ca = vextq_u32(zero, a__b__ca_db, 3);
119232
let a_ab_abc_abcd = vaddq_u32(___a__b__ca, a__b__ca_db);
120-
let mut r = set1(0);
121-
vst1q_u32(r.as_mut_ptr(), vaddq_u32(base, a_ab_abc_abcd));
122-
r
233+
vaddq_u32(base, a_ab_abc_abcd)
123234
}
124235

125-
// TODO trinity-1686a: I believe add/sub are easy enough for the compiler to optimize on its
126-
// own, and suspect hand-rolled impl would force (un)loading registers and make things slower
127-
// overall
236+
#[inline]
237+
unsafe fn add(left: DataType, right: DataType) -> DataType {
238+
vaddq_u32(left, right)
239+
}
240+
241+
#[inline]
242+
unsafe fn sub(left: DataType, right: DataType) -> DataType {
243+
vsubq_u32(left, right)
244+
}
128245

129246
declare_bitpacker!(target_feature(enable = "neon"));
130247

src/bitpacking_bench.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
use criterion::{Bencher, Criterion, criterion_group, criterion_main};
1+
use criterion::{criterion_group, criterion_main, Bencher, Criterion};
2+
use std::time::Duration;
23

34
use bitpacking::{BitPacker, BitPacker1x, BitPacker4x, BitPacker8x};
45
use criterion::Benchmark;
56
use criterion::Throughput;
67

78
const NUM_BLOCKS: usize = 10;
9+
const SAMPLE_SIZE: usize = 10;
10+
const WARM_UP_TIME: Duration = Duration::from_millis(50);
811

912
fn integrate_data(initial: u32, data: &mut [u32]) {
1013
let mut cumul = initial;
@@ -245,6 +248,8 @@ fn criterion_benchmark_bitpacker<TBitPacker: BitPacker + 'static>(
245248
Benchmark::new(format!("decompress-{num_bit}").as_str(), move |b| {
246249
bench_decompress_util::<TBitPacker>(bitpacker, b, &num_bits[..]);
247250
})
251+
.warm_up_time(WARM_UP_TIME)
252+
.sample_size(SAMPLE_SIZE)
248253
.throughput(Throughput::Elements(
249254
(NUM_BLOCKS * TBitPacker::BLOCK_LEN) as u64,
250255
)),
@@ -254,6 +259,8 @@ fn criterion_benchmark_bitpacker<TBitPacker: BitPacker + 'static>(
254259
Benchmark::new(format!("decompress-delta-{num_bit}").as_str(), move |b| {
255260
bench_decompress_delta_util::<TBitPacker>(bitpacker, b, &num_bits[..]);
256261
})
262+
.warm_up_time(WARM_UP_TIME)
263+
.sample_size(SAMPLE_SIZE)
257264
.throughput(Throughput::Elements(
258265
(NUM_BLOCKS * TBitPacker::BLOCK_LEN) as u64,
259266
)),
@@ -266,6 +273,8 @@ fn criterion_benchmark_bitpacker<TBitPacker: BitPacker + 'static>(
266273
bench_decompress_strict_delta_util::<TBitPacker>(bitpacker, b, &num_bits[..]);
267274
},
268275
)
276+
.warm_up_time(WARM_UP_TIME)
277+
.sample_size(SAMPLE_SIZE)
269278
.throughput(Throughput::Elements(
270279
(NUM_BLOCKS * TBitPacker::BLOCK_LEN) as u64,
271280
)),
@@ -275,6 +284,8 @@ fn criterion_benchmark_bitpacker<TBitPacker: BitPacker + 'static>(
275284
Benchmark::new(format!("compress-{num_bit}").as_str(), move |b| {
276285
bench_compress_util::<TBitPacker>(bitpacker, b, &num_bits[..]);
277286
})
287+
.warm_up_time(WARM_UP_TIME)
288+
.sample_size(SAMPLE_SIZE)
278289
.throughput(Throughput::Elements(
279290
(NUM_BLOCKS * TBitPacker::BLOCK_LEN) as u64,
280291
)),
@@ -284,6 +295,8 @@ fn criterion_benchmark_bitpacker<TBitPacker: BitPacker + 'static>(
284295
Benchmark::new(format!("compress-delta-{num_bit}").as_str(), move |b| {
285296
bench_compress_delta_util::<TBitPacker>(bitpacker, b, &num_bits[..]);
286297
})
298+
.warm_up_time(WARM_UP_TIME)
299+
.sample_size(SAMPLE_SIZE)
287300
.throughput(Throughput::Elements(
288301
(NUM_BLOCKS * TBitPacker::BLOCK_LEN) as u64,
289302
)),
@@ -296,6 +309,8 @@ fn criterion_benchmark_bitpacker<TBitPacker: BitPacker + 'static>(
296309
bench_compress_strict_delta_util::<TBitPacker>(bitpacker, b, &num_bits[..]);
297310
},
298311
)
312+
.warm_up_time(WARM_UP_TIME)
313+
.sample_size(SAMPLE_SIZE)
299314
.throughput(Throughput::Elements(
300315
(NUM_BLOCKS * TBitPacker::BLOCK_LEN) as u64,
301316
)),
@@ -309,5 +324,9 @@ fn criterion_benchmark(criterion: &mut Criterion) {
309324
criterion_benchmark_bitpacker("BitPacker8x", BitPacker8x::new(), criterion);
310325
}
311326

312-
criterion_group!(benches, criterion_benchmark);
327+
criterion_group! {
328+
name = benches;
329+
config = Criterion::default().warm_up_time(Duration::from_millis(50));
330+
targets = criterion_benchmark
331+
}
313332
criterion_main!(benches);

src/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
use super::most_significant_bit;
2+
use super::UnsafeBitPacker;
13
use rand::distributions::{Distribution as _, Uniform};
24
use rand::rngs::StdRng;
35
use rand::SeedableRng as _;
4-
use super::most_significant_bit;
5-
use super::UnsafeBitPacker;
66

77
pub fn generate_array(n: usize, max_num_bits: u8) -> Vec<u32> {
88
assert!(max_num_bits <= 32u8);

0 commit comments

Comments
 (0)