@@ -83,48 +83,165 @@ mod neon {
8383
8484 use super :: BLOCK_LEN ;
8585 use crate :: Available ;
86+ use std:: arch:: aarch64:: {
87+ uint32x4_t, vaddq_u32, vandq_u32, vdupq_n_u32, vextq_u32, vgetq_lane_u32, vld1q_u32,
88+ vorrq_u32, vshlq_n_u32, vshrq_n_u32, vst1q_u32, vsubq_u32,
89+ } ;
90+
/// Vector type every NEON helper in this module operates on: `uint32x4_t`, treated below as four `u32` lanes.
91+ pub ( crate ) type DataType = uint32x4_t ;
92+
93+ #[ inline]
94+ /// Creates a vector with all elements set to `el`.
95+ unsafe fn set1 ( el : i32 ) -> DataType {
96+ vdupq_n_u32 ( el as u32 )
97+ }
98+
99+ #[ inline]
100+ unsafe fn right_shift_32 < const N : i32 > ( el : DataType ) -> DataType {
101+ const {
102+ assert ! ( N >= 0 ) ;
103+ assert ! ( N <= 32 ) ;
104+ }
105+
106+ // We unroll here because vshrq_n_u32 only accepts constants from 1 to 32.
107+ match N {
108+ 0 => el,
109+ 1 => vshrq_n_u32 :: < 1 > ( el) ,
110+ 2 => vshrq_n_u32 :: < 2 > ( el) ,
111+ 3 => vshrq_n_u32 :: < 3 > ( el) ,
112+ 4 => vshrq_n_u32 :: < 4 > ( el) ,
113+ 5 => vshrq_n_u32 :: < 5 > ( el) ,
114+ 6 => vshrq_n_u32 :: < 6 > ( el) ,
115+ 7 => vshrq_n_u32 :: < 7 > ( el) ,
116+ 8 => vshrq_n_u32 :: < 8 > ( el) ,
117+ 9 => vshrq_n_u32 :: < 9 > ( el) ,
118+ 10 => vshrq_n_u32 :: < 10 > ( el) ,
119+ 11 => vshrq_n_u32 :: < 11 > ( el) ,
120+ 12 => vshrq_n_u32 :: < 12 > ( el) ,
121+ 13 => vshrq_n_u32 :: < 13 > ( el) ,
122+ 14 => vshrq_n_u32 :: < 14 > ( el) ,
123+ 15 => vshrq_n_u32 :: < 15 > ( el) ,
124+ 16 => vshrq_n_u32 :: < 16 > ( el) ,
125+ 17 => vshrq_n_u32 :: < 17 > ( el) ,
126+ 18 => vshrq_n_u32 :: < 18 > ( el) ,
127+ 19 => vshrq_n_u32 :: < 19 > ( el) ,
128+ 20 => vshrq_n_u32 :: < 20 > ( el) ,
129+ 21 => vshrq_n_u32 :: < 21 > ( el) ,
130+ 22 => vshrq_n_u32 :: < 22 > ( el) ,
131+ 23 => vshrq_n_u32 :: < 23 > ( el) ,
132+ 24 => vshrq_n_u32 :: < 24 > ( el) ,
133+ 25 => vshrq_n_u32 :: < 25 > ( el) ,
134+ 26 => vshrq_n_u32 :: < 26 > ( el) ,
135+ 27 => vshrq_n_u32 :: < 27 > ( el) ,
136+ 28 => vshrq_n_u32 :: < 28 > ( el) ,
137+ 29 => vshrq_n_u32 :: < 29 > ( el) ,
138+ 30 => vshrq_n_u32 :: < 30 > ( el) ,
139+ 31 => vshrq_n_u32 :: < 31 > ( el) ,
140+ 32 => vdupq_n_u32 ( 0 ) ,
141+ _ => core:: hint:: unreachable_unchecked ( ) ,
142+ }
143+ }
144+
145+ #[ inline]
146+ unsafe fn left_shift_32 < const N : i32 > ( el : DataType ) -> DataType {
147+ const {
148+ assert ! ( N >= 0 ) ;
149+ assert ! ( N <= 32 ) ;
150+ }
151+
152+ // We unroll here because vshlq_n_u32 only accepts constants from 0 to 31.
153+ match N {
154+ 0 => el,
155+ 1 => vshlq_n_u32 :: < 1 > ( el) ,
156+ 2 => vshlq_n_u32 :: < 2 > ( el) ,
157+ 3 => vshlq_n_u32 :: < 3 > ( el) ,
158+ 4 => vshlq_n_u32 :: < 4 > ( el) ,
159+ 5 => vshlq_n_u32 :: < 5 > ( el) ,
160+ 6 => vshlq_n_u32 :: < 6 > ( el) ,
161+ 7 => vshlq_n_u32 :: < 7 > ( el) ,
162+ 8 => vshlq_n_u32 :: < 8 > ( el) ,
163+ 9 => vshlq_n_u32 :: < 9 > ( el) ,
164+ 10 => vshlq_n_u32 :: < 10 > ( el) ,
165+ 11 => vshlq_n_u32 :: < 11 > ( el) ,
166+ 12 => vshlq_n_u32 :: < 12 > ( el) ,
167+ 13 => vshlq_n_u32 :: < 13 > ( el) ,
168+ 14 => vshlq_n_u32 :: < 14 > ( el) ,
169+ 15 => vshlq_n_u32 :: < 15 > ( el) ,
170+ 16 => vshlq_n_u32 :: < 16 > ( el) ,
171+ 17 => vshlq_n_u32 :: < 17 > ( el) ,
172+ 18 => vshlq_n_u32 :: < 18 > ( el) ,
173+ 19 => vshlq_n_u32 :: < 19 > ( el) ,
174+ 20 => vshlq_n_u32 :: < 20 > ( el) ,
175+ 21 => vshlq_n_u32 :: < 21 > ( el) ,
176+ 22 => vshlq_n_u32 :: < 22 > ( el) ,
177+ 23 => vshlq_n_u32 :: < 23 > ( el) ,
178+ 24 => vshlq_n_u32 :: < 24 > ( el) ,
179+ 25 => vshlq_n_u32 :: < 25 > ( el) ,
180+ 26 => vshlq_n_u32 :: < 26 > ( el) ,
181+ 27 => vshlq_n_u32 :: < 27 > ( el) ,
182+ 28 => vshlq_n_u32 :: < 28 > ( el) ,
183+ 29 => vshlq_n_u32 :: < 29 > ( el) ,
184+ 30 => vshlq_n_u32 :: < 30 > ( el) ,
185+ 31 => vshlq_n_u32 :: < 31 > ( el) ,
186+ 32 => vdupq_n_u32 ( 0 ) ,
187+ _ => core:: hint:: unreachable_unchecked ( ) ,
188+ }
189+ }
190+
191+ use vorrq_u32 as op_or;
192+
193+ #[ inline]
194+ unsafe fn op_and ( left : DataType , right : DataType ) -> DataType {
195+ vandq_u32 ( left, right)
196+ }
197+
198+ #[ inline]
199+ unsafe fn load_unaligned ( addr : * const DataType ) -> DataType {
200+ vld1q_u32 ( addr. cast :: < u32 > ( ) )
201+ }
202+
203+ #[ inline]
204+ unsafe fn store_unaligned ( addr : * mut DataType , data : DataType ) {
205+ vst1q_u32 ( addr. cast :: < u32 > ( ) , data) ;
206+ }
207+
208+ #[ inline]
209+ /// Collapses the vector by performing a bitwise OR across all lanes
210+ unsafe fn or_collapse_to_u32 ( acc : DataType ) -> u32 {
211+ vgetq_lane_u32 ( acc, 0 )
212+ | vgetq_lane_u32 ( acc, 1 )
213+ | vgetq_lane_u32 ( acc, 2 )
214+ | vgetq_lane_u32 ( acc, 3 )
215+ }
86216
87- use super :: scalar:: add;
88- use super :: scalar:: left_shift_32;
89- use super :: scalar:: load_unaligned;
90- use super :: scalar:: op_and;
91- use super :: scalar:: op_or;
92- use super :: scalar:: or_collapse_to_u32;
93- use super :: scalar:: right_shift_32;
94- use super :: scalar:: set1;
95- use super :: scalar:: store_unaligned;
96- use super :: scalar:: sub;
97- use super :: scalar:: DataType ;
98- use std:: arch:: aarch64:: { vaddq_u32, vdupq_n_u32, vextq_u32, vld1q_u32, vst1q_u32, vsubq_u32} ;
99-
100- #[ target_feature( enable = "neon" ) ]
101217 unsafe fn compute_delta ( curr : DataType , prev : DataType ) -> DataType {
102- let c = vld1q_u32 ( curr. as_ptr ( ) ) ;
103- let p = vld1q_u32 ( prev. as_ptr ( ) ) ;
104- let mut r = set1 ( 0 ) ;
105- vst1q_u32 ( r. as_mut_ptr ( ) , vsubq_u32 ( c, vextq_u32 ( p, c, 3 ) ) ) ;
106- r
218+ // Build a vector with [prev[3], curr[0], curr[1], curr[2]]
219+ let prev_shifted = vextq_u32 ( prev, curr, 3 ) ;
220+ vsubq_u32 ( curr, prev_shifted)
107221 }
108222
109- #[ target_feature( enable = "neon" ) ]
110223 #[ allow( non_snake_case) ]
111224 #[ inline]
112225 unsafe fn integrate_delta ( prev : DataType , delta : DataType ) -> DataType {
113- let base = vdupq_n_u32 ( prev[ 3 ] ) ;
226+ let base = vdupq_n_u32 ( vgetq_lane_u32 ( prev, 3 ) ) ;
114227 let zero = vdupq_n_u32 ( 0 ) ;
115- let a__b__c__d_ = vld1q_u32 ( delta. as_ptr ( ) ) ;
228+ let a__b__c__d_ = delta;
116229 let ______a__b_ = vextq_u32 ( zero, a__b__c__d_, 2 ) ;
117230 let a__b__ca_db = vaddq_u32 ( ______a__b_, a__b__c__d_) ;
118231 let ___a__b__ca = vextq_u32 ( zero, a__b__ca_db, 3 ) ;
119232 let a_ab_abc_abcd = vaddq_u32 ( ___a__b__ca, a__b__ca_db) ;
120- let mut r = set1 ( 0 ) ;
121- vst1q_u32 ( r. as_mut_ptr ( ) , vaddq_u32 ( base, a_ab_abc_abcd) ) ;
122- r
233+ vaddq_u32 ( base, a_ab_abc_abcd)
123234 }
124235
125- // TODO trinity-1686a: I believe add/sub are easy enough for the compiler to optimize on its
126- // own, and suspect hand-rolled impl would force (un)loading registers and make things slower
127- // overall
236+ #[ inline]
237+ unsafe fn add ( left : DataType , right : DataType ) -> DataType {
238+ vaddq_u32 ( left, right)
239+ }
240+
241+ #[ inline]
242+ unsafe fn sub ( left : DataType , right : DataType ) -> DataType {
243+ vsubq_u32 ( left, right)
244+ }
128245
// NOTE(review): the macro is defined outside this view; presumably it expands to the bit-packer
// implementation built on the helpers above, applying the given attribute — confirm at its definition.
129246 declare_bitpacker ! ( target_feature( enable = "neon" ) ) ;
130247
0 commit comments