44//! Adapted from the Crypto++ `chacha_simd` implementation by Jack Lloyd and
55//! Jeffrey Walton (public domain).
66
7- use crate :: { Rounds , STATE_WORDS } ;
7+ use crate :: { Rounds , STATE_WORDS , Variant } ;
88use core:: { arch:: aarch64:: * , marker:: PhantomData } ;
99
1010#[ cfg( feature = "rand_core" ) ]
11- use crate :: { ChaChaCore , Variant } ;
11+ use crate :: ChaChaCore ;
1212
1313#[ cfg( feature = "cipher" ) ]
1414use crate :: chacha:: Block ;
@@ -19,13 +19,26 @@ use cipher::{
1919 consts:: { U4 , U64 } ,
2020} ;
2121
/// NEON backend state for the ChaCha block function.
///
/// This change adds a second type parameter `V: Variant`; like `R` it is not
/// stored and is carried only through `PhantomData`.
22- struct Backend < R : Rounds > {
22+ struct Backend < R : Rounds , V : Variant > {
// Four 4x32-bit vectors; `state[3]` is the row the counter increments are applied to.
2323 state : [ uint32x4_t ; 4 ] ,
// Per-block increments that are added to `state[3]`.
2424 ctrs : [ uint32x4_t ; 4 ] ,
25- _pd : PhantomData < R > ,
// Zero-sized marker tying the backend to its round count and cipher variant.
25+ _pd : PhantomData < ( R , V ) > ,
2626}
2727
28- impl < R : Rounds > Backend < R > {
// Adds the increment vector `$b` to the vector `$a`, choosing the lane width
// from the byte size of `<$variant>::Counter`.
//
// * size 4: plain lane-wise 32-bit addition (`vaddq_u32`).
// * size 8: both operands are reinterpreted as two 64-bit lanes, added with
//   `vaddq_u64`, and reinterpreted back to four 32-bit lanes.
//   NOTE(review): presumably this lets a carry out of the low counter word
//   propagate into the next word - confirm against `Variant::Counter`.
// * any other size: `unreachable!()`.
28+ macro_rules! add_counter {
29+ ( $a: expr, $b: expr, $variant: ty) => {
30+ match size_of:: <<$variant>:: Counter >( ) {
31+ 4 => vaddq_u32( $a, $b) ,
32+ 8 => vreinterpretq_u32_u64( vaddq_u64(
33+ vreinterpretq_u64_u32( $a) ,
34+ vreinterpretq_u64_u32( $b) ,
35+ ) ) ,
36+ _ => unreachable!( ) ,
37+ }
38+ } ;
39+ }
40+
41+ impl < R : Rounds , V : Variant > Backend < R , V > {
2942 #[ inline]
3043 unsafe fn new ( state : & mut [ u32 ; STATE_WORDS ] ) -> Self {
3144 let state = [
@@ -40,7 +53,7 @@ impl<R: Rounds> Backend<R> {
4053 vld1q_u32 ( [ 3 , 0 , 0 , 0 ] . as_ptr ( ) ) ,
4154 vld1q_u32 ( [ 4 , 0 , 0 , 0 ] . as_ptr ( ) ) ,
4255 ] ;
43- Backend :: < R > {
56+ Backend :: < R , V > {
4457 state,
4558 ctrs,
4659 _pd : PhantomData ,
@@ -51,16 +64,24 @@ impl<R: Rounds> Backend<R> {
/// Builds a NEON `Backend` from `state`, hands it to the stream-cipher
/// closure `f`, then writes the backend's counter row back into `state`.
///
/// How much is written back depends on the byte size of `V::Counter`:
/// with a 4-byte counter only `state[12]` is updated; with an 8-byte
/// counter the whole vector `backend.state[3]` is stored at word offset 12.
///
/// # Safety
///
/// The function is compiled with `target_feature(enable = "neon")`, so the
/// caller is expected to ensure NEON is available on the running CPU.
5164#[ inline]
5265#[ cfg( feature = "cipher" ) ]
5366#[ target_feature( enable = "neon" ) ]
54- pub ( crate ) unsafe fn inner < R , F > ( state : & mut [ u32 ; STATE_WORDS ] , f : F )
67+ pub ( crate ) unsafe fn inner < R , F , V > ( state : & mut [ u32 ; STATE_WORDS ] , f : F )
5568where
5669 R : Rounds ,
5770 F : StreamCipherClosure < BlockSize = U64 > ,
71+ V : Variant ,
5872{
59- let mut backend = Backend :: < R > :: new ( state) ;
73+ let mut backend = Backend :: < R , V > :: new ( state) ;
6074
// Run the caller's closure against the backend; this is where keystream blocks are generated.
6175 f. call ( & mut backend) ;
6276
// Persist the updated counter row so the next call continues from it.
63- vst1q_u32 ( state. as_mut_ptr ( ) . offset ( 12 ) , backend. state [ 3 ] ) ;
77+ match size_of :: < V :: Counter > ( ) {
// 32-bit counter: only lane 0 is written back to word 12.
78+ 4 => state[ 12 ] = vgetq_lane_u32 ( backend. state [ 3 ] , 0 ) ,
// 64-bit counter: the full 128-bit vector is stored starting at word 12.
79+ 8 => vst1q_u64 (
80+ state. as_mut_ptr ( ) . offset ( 12 ) as * mut u64 ,
81+ vreinterpretq_u64_u32 ( backend. state [ 3 ] ) ,
82+ ) ,
83+ _ => unreachable ! ( ) ,
84+ }
6485}
6586
6687#[ inline]
@@ -73,19 +94,22 @@ where
7394 R : Rounds ,
7495 V : Variant ,
7596{
76- let mut backend = Backend :: < R > :: new ( & mut core. state ) ;
97+ let mut backend = Backend :: < R , V > :: new ( & mut core. state ) ;
7798
7899 backend. write_par_ks_blocks ( buffer) ;
79100
80- vst1q_u32 ( core. state . as_mut_ptr ( ) . offset ( 12 ) , backend. state [ 3 ] ) ;
101+ vst1q_u64 (
102+ core. state . as_mut_ptr ( ) . offset ( 12 ) as * mut u64 ,
103+ vreinterpretq_u64_u32 ( backend. state [ 3 ] ) ,
104+ ) ;
81105}
82106
// Declares the backend's block size as `U64`; only compiled with the `cipher` feature.
83107#[ cfg( feature = "cipher" ) ]
84- impl < R : Rounds > BlockSizeUser for Backend < R > {
108+ impl < R : Rounds , V : Variant > BlockSizeUser for Backend < R , V > {
85109 type BlockSize = U64 ;
86110}
// Declares the parallel batch size as `U4`, matching the `for block in 0 ..4`
// loops of the parallel keystream routines; only compiled with the `cipher` feature.
87111#[ cfg( feature = "cipher" ) ]
88- impl < R : Rounds > ParBlocksSizeUser for Backend < R > {
112+ impl < R : Rounds , V : Variant > ParBlocksSizeUser for Backend < R , V > {
89113 type ParBlocksSize = U4 ;
90114}
91115
@@ -97,15 +121,15 @@ macro_rules! add_assign_vec {
97121}
98122
99123#[ cfg( feature = "cipher" ) ]
100- impl < R : Rounds > StreamCipherBackend for Backend < R > {
124+ impl < R : Rounds , V : Variant > StreamCipherBackend for Backend < R , V > {
/// Produces a single keystream block by generating a full parallel batch
/// and keeping only its first block, then advancing the counter by one.
101125 #[ inline( always) ]
102126 fn gen_ks_block ( & mut self , block : & mut Block ) {
// Remember the counter row: `gen_par_ks_blocks` advances `self.state[3]` by `ctrs[3]`.
103127 let state3 = self . state [ 3 ] ;
104128 let mut par = ParBlocks :: < Self > :: default ( ) ;
105129 self . gen_par_ks_blocks ( & mut par) ;
// Only the first generated block is handed to the caller; the rest are discarded.
106130 * block = par[ 0 ] ;
// NOTE(review): NEON intrinsics; presumably sound because this backend is only
// constructed when NEON is available - confirm at the call sites.
107131 unsafe {
// Overwrite the batch-advanced counter with "saved counter + 1".
108- self . state [ 3 ] = vaddq_u32 ( state3, vld1q_u32 ( [ 1 , 0 , 0 , 0 ] . as_ptr ( ) ) ) ;
132+ self . state [ 3 ] = add_counter ! ( state3, vld1q_u32( [ 1 , 0 , 0 , 0 ] . as_ptr( ) ) , V ) ;
109133 }
110134 }
111135
@@ -118,19 +142,19 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
118142 self . state [ 0 ] ,
119143 self . state [ 1 ] ,
120144 self . state [ 2 ] ,
121- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 0 ] ) ,
145+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 0 ] , V ) ,
122146 ] ,
123147 [
124148 self . state [ 0 ] ,
125149 self . state [ 1 ] ,
126150 self . state [ 2 ] ,
127- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 1 ] ) ,
151+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 1 ] , V ) ,
128152 ] ,
129153 [
130154 self . state [ 0 ] ,
131155 self . state [ 1 ] ,
132156 self . state [ 2 ] ,
133- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 2 ] ) ,
157+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 2 ] , V ) ,
134158 ] ,
135159 ] ;
136160
@@ -140,11 +164,16 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
140164
141165 for block in 0 ..4 {
142166 // add state to block
143- for state_row in 0 ..4 {
167+ for state_row in 0 ..3 {
144168 add_assign_vec ! ( blocks[ block] [ state_row] , self . state[ state_row] ) ;
145169 }
146170 if block > 0 {
147- blocks[ block] [ 3 ] = vaddq_u32 ( blocks[ block] [ 3 ] , self . ctrs [ block - 1 ] ) ;
171+ add_assign_vec ! (
172+ blocks[ block] [ 3 ] ,
173+ add_counter!( self . state[ 3 ] , self . ctrs[ block - 1 ] , V )
174+ ) ;
175+ } else {
176+ add_assign_vec ! ( blocks[ block] [ 3 ] , self . state[ 3 ] ) ;
148177 }
149178 // write blocks to dest
150179 for state_row in 0 ..4 {
@@ -154,7 +183,7 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {
154183 ) ;
155184 }
156185 }
157- self . state [ 3 ] = vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 3 ] ) ;
186+ self . state [ 3 ] = add_counter ! ( self . state[ 3 ] , self . ctrs[ 3 ] , V ) ;
158187 }
159188 }
160189}
@@ -180,7 +209,7 @@ macro_rules! extract {
180209 } ;
181210}
182211
183- impl < R : Rounds > Backend < R > {
212+ impl < R : Rounds , V : Variant > Backend < R , V > {
184213 #[ inline( always) ]
185214 /// Generates `num_blocks` blocks and blindly writes them to `dest_ptr`
186215 ///
@@ -197,19 +226,19 @@ impl<R: Rounds> Backend<R> {
197226 self . state [ 0 ] ,
198227 self . state [ 1 ] ,
199228 self . state [ 2 ] ,
200- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 0 ] ) ,
229+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 0 ] , V ) ,
201230 ] ,
202231 [
203232 self . state [ 0 ] ,
204233 self . state [ 1 ] ,
205234 self . state [ 2 ] ,
206- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 1 ] ) ,
235+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 1 ] , V ) ,
207236 ] ,
208237 [
209238 self . state [ 0 ] ,
210239 self . state [ 1 ] ,
211240 self . state [ 2 ] ,
212- vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 2 ] ) ,
241+ add_counter ! ( self . state[ 3 ] , self . ctrs[ 2 ] , V ) ,
213242 ] ,
214243 ] ;
215244
@@ -220,11 +249,16 @@ impl<R: Rounds> Backend<R> {
220249 let mut dest_ptr = buffer. as_mut_ptr ( ) as * mut u8 ;
221250 for block in 0 ..4 {
222251 // add state to block
223- for state_row in 0 ..4 {
252+ for state_row in 0 ..3 {
224253 add_assign_vec ! ( blocks[ block] [ state_row] , self . state[ state_row] ) ;
225254 }
226255 if block > 0 {
227- blocks[ block] [ 3 ] = vaddq_u32 ( blocks[ block] [ 3 ] , self . ctrs [ block - 1 ] ) ;
256+ add_assign_vec ! (
257+ blocks[ block] [ 3 ] ,
258+ add_counter!( self . state[ 3 ] , self . ctrs[ block - 1 ] , V )
259+ ) ;
260+ } else {
261+ add_assign_vec ! ( blocks[ block] [ 3 ] , self . state[ 3 ] ) ;
228262 }
229263 // write blocks to buffer
230264 for state_row in 0 ..4 {
@@ -235,7 +269,7 @@ impl<R: Rounds> Backend<R> {
235269 }
236270 dest_ptr = dest_ptr. add ( 64 ) ;
237271 }
238- self . state [ 3 ] = vaddq_u32 ( self . state [ 3 ] , self . ctrs [ 3 ] ) ;
272+ self . state [ 3 ] = add_counter ! ( self . state[ 3 ] , self . ctrs[ 3 ] , V ) ;
239273 }
240274}
241275
0 commit comments