-
Notifications
You must be signed in to change notification settings - Fork 61
[WIP] 64-bit counters for legacy and XChaCha variants, fix looping counter regressions #399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
5674889
6305647
ee9819c
ebbbbec
dd1e10d
4c0131c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
#![allow(unsafe_op_in_unsafe_fn)] | ||
use crate::Rounds; | ||
use crate::{Rounds,Variant}; | ||
|
||
#[cfg(feature = "rng")] | ||
use crate::{ChaChaCore, Variant}; | ||
use crate::{ChaChaCore}; | ||
|
||
#[cfg(feature = "cipher")] | ||
use crate::{chacha::Block, STATE_WORDS}; | ||
|
@@ -23,49 +23,77 @@ const PAR_BLOCKS: usize = 4; | |
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
#[cfg(feature = "cipher")] | ||
pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F) | ||
pub(crate) unsafe fn inner<R, V, F>(counter: &mut u64, state: &mut [u32; STATE_WORDS], f: F) | ||
where | ||
R: Rounds, | ||
V: Variant, | ||
F: StreamCipherClosure<BlockSize = U64>, | ||
{ | ||
let state_ptr = state.as_ptr() as *const __m128i; | ||
let mut backend = Backend::<R> { | ||
let mut backend = Backend::<R,V> { | ||
v: [ | ||
_mm_loadu_si128(state_ptr.add(0)), | ||
_mm_loadu_si128(state_ptr.add(1)), | ||
_mm_loadu_si128(state_ptr.add(2)), | ||
_mm_loadu_si128(state_ptr.add(3)), | ||
], | ||
counter: *counter, | ||
_pd: PhantomData, | ||
}; | ||
|
||
f.call(&mut backend); | ||
|
||
state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32; | ||
*counter = backend.counter; | ||
|
||
if V::COUNTER_SIZE == 1 { | ||
state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32; | ||
} else { | ||
let ctr = _mm_cvtsi128_si64(backend.v[3]) as u64; | ||
|
||
state[12] = (ctr&(u32::MAX as u64)) as u32; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the |
||
state[13] = (ctr>>32) as u32; | ||
} | ||
|
||
} | ||
|
||
struct Backend<R: Rounds> { | ||
struct Backend<R: Rounds, V: Variant> { | ||
v: [__m128i; 4], | ||
_pd: PhantomData<R>, | ||
counter: u64, | ||
_pd: PhantomData<(R,V)>, | ||
} | ||
|
||
#[cfg(feature = "cipher")] | ||
impl<R: Rounds> BlockSizeUser for Backend<R> { | ||
impl<R: Rounds,V: Variant> BlockSizeUser for Backend<R,V> { | ||
type BlockSize = U64; | ||
} | ||
|
||
#[cfg(feature = "cipher")] | ||
impl<R: Rounds> ParBlocksSizeUser for Backend<R> { | ||
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> { | ||
type ParBlocksSize = U4; | ||
} | ||
|
||
#[cfg(feature = "cipher")] | ||
impl<R: Rounds> StreamCipherBackend for Backend<R> { | ||
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> { | ||
#[inline(always)] | ||
fn gen_ks_block(&mut self, block: &mut Block) { | ||
self.counter = self.counter.saturating_add(1); | ||
unsafe { | ||
let res = rounds::<R>(&self.v); | ||
self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1)); | ||
let res = rounds::<R,V>(&self.v); | ||
if V::COUNTER_SIZE == 1 { | ||
let mask = _mm_add_epi32( | ||
_mm_and_si128( | ||
_mm_cmpeq_epi32(self.v[3], _mm_set_epi32(0,0,0,-1)), | ||
_mm_set_epi32(0,0,0,-1)), | ||
_mm_set_epi32(0,0,0,1)); | ||
self.v[3] = _mm_add_epi32(self.v[3], _mm_and_si128(mask, _mm_set_epi32(0, 0, 0, 1))); | ||
} else { | ||
let mask = _mm_add_epi64( | ||
_mm_and_si128( | ||
_mm_cmpeq_epi64(self.v[3], _mm_set_epi64x(0,-1)), | ||
_mm_set_epi64x(0,-1)), | ||
_mm_set_epi64x(0,1)); | ||
self.v[3] = _mm_add_epi64(self.v[3], _mm_and_si128(mask, _mm_set_epi64x(0, 1))); | ||
} | ||
|
||
let block_ptr = block.as_mut_ptr() as *mut __m128i; | ||
for i in 0..4 { | ||
|
@@ -75,9 +103,34 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> { | |
} | ||
#[inline(always)] | ||
fn gen_par_ks_blocks(&mut self, blocks: &mut cipher::ParBlocks<Self>) { | ||
if V::COUNTER_SIZE == 1 { | ||
self.counter = core::cmp::min(V::MAX_USABLE_COUNTER+1, | ||
self.counter.saturating_add(PAR_BLOCKS as u64)); | ||
} else { | ||
self.counter = self.counter.saturating_add(PAR_BLOCKS as u64); | ||
} | ||
unsafe { | ||
let res = rounds::<R>(&self.v); | ||
self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32)); | ||
let res = rounds::<R,V>(&self.v); | ||
if V::COUNTER_SIZE == 1 { | ||
|
||
let new_v3 = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32)); | ||
let shifted = _mm_add_epi32(self.v[3], _mm_set_epi32(0,0,0,i32::MIN)); | ||
let new_shifted = _mm_add_epi32(new_v3, _mm_set_epi32(0,0,0,i32::MIN)); | ||
let mask = _mm_cmpgt_epi32(shifted,new_shifted); | ||
let max_val = _mm_and_si128(mask,_mm_set_epi32(0,0,0,-1)); | ||
let new_val = _mm_andnot_si128(mask,new_v3); | ||
self.v[3] = _mm_or_si128(max_val,new_val); | ||
|
||
} else { | ||
let new_v3 = _mm_add_epi64(self.v[3], _mm_set_epi64x(0, PAR_BLOCKS as i64)); | ||
|
||
let shifted = _mm_add_epi64(self.v[3], _mm_set_epi64x(0,i64::MIN)); | ||
let new_shifted = _mm_add_epi64(new_v3, _mm_set_epi64x(0,i64::MIN)); | ||
let mask = _mm_cmpgt_epi64(shifted,new_shifted); | ||
let max_val = _mm_and_si128(mask,_mm_set_epi64x(0,-1)); | ||
let new_val = _mm_andnot_si128(mask,new_v3); | ||
self.v[3] = _mm_or_si128(max_val,new_val); | ||
} | ||
|
||
let blocks_ptr = blocks.as_mut_ptr() as *mut __m128i; | ||
for block in 0..PAR_BLOCKS { | ||
|
@@ -98,28 +151,50 @@ where | |
V: Variant, | ||
{ | ||
let state_ptr = core.state.as_ptr() as *const __m128i; | ||
let mut backend = Backend::<R> { | ||
let mut backend = Backend::<R,V> { | ||
v: [ | ||
_mm_loadu_si128(state_ptr.add(0)), | ||
_mm_loadu_si128(state_ptr.add(1)), | ||
_mm_loadu_si128(state_ptr.add(2)), | ||
_mm_loadu_si128(state_ptr.add(3)), | ||
], | ||
counter: core.counter, | ||
_pd: PhantomData, | ||
}; | ||
|
||
backend.gen_ks_blocks(buffer); | ||
|
||
core.state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32; | ||
core.counter = backend.counter; | ||
if V::COUNTER_SIZE == 1 { | ||
core.state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32; | ||
} else { | ||
let ctr = _mm_cvtsi128_si64(backend.v[3]) as u64; | ||
|
||
core.state[12] = (ctr&(u32::MAX as u64)) as u32; | ||
core.state[13] = (ctr>>32) as u32; | ||
} | ||
} | ||
|
||
#[cfg(feature = "rng")] | ||
impl<R: Rounds> Backend<R> { | ||
impl<R: Rounds, V: Variant> Backend<R, V> { | ||
#[inline(always)] | ||
fn gen_ks_blocks(&mut self, block: &mut [u32]) { | ||
if V::COUNTER_SIZE == 1 { | ||
self.counter = V::MAX_USABLE_COUNTER & | ||
self.counter.saturating_add(PAR_BLOCKS as u64); | ||
} else { | ||
self.counter = self.counter.saturating_add(PAR_BLOCKS as u64); | ||
} | ||
|
||
unsafe { | ||
let res = rounds::<R>(&self.v); | ||
self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32)); | ||
let res = rounds::<R,V>(&self.v); | ||
if V::COUNTER_SIZE == 1 { | ||
|
||
self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32)); | ||
|
||
} else { | ||
self.v[3] = _mm_add_epi64(self.v[3], _mm_set_epi64x(0, PAR_BLOCKS as i64)); | ||
} | ||
|
||
let blocks_ptr = block.as_mut_ptr() as *mut __m128i; | ||
for block in 0..PAR_BLOCKS { | ||
|
@@ -133,10 +208,15 @@ impl<R: Rounds> Backend<R> { | |
|
||
#[inline] | ||
#[target_feature(enable = "sse2")] | ||
unsafe fn rounds<R: Rounds>(v: &[__m128i; 4]) -> [[__m128i; 4]; PAR_BLOCKS] { | ||
unsafe fn rounds<R: Rounds, V: Variant>(v: &[__m128i; 4]) -> [[__m128i; 4]; PAR_BLOCKS] { | ||
let mut res = [*v; 4]; | ||
for block in 1..PAR_BLOCKS { | ||
res[block][3] = _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32)); | ||
if V::COUNTER_SIZE == 1 { | ||
res[block][3] = _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32)); | ||
} else { | ||
res[block][3] = _mm_add_epi64(res[block][3], _mm_set_epi64x(0, block as i64)); | ||
} | ||
|
||
} | ||
|
||
for _ in 0..R::COUNT { | ||
|
@@ -148,7 +228,11 @@ unsafe fn rounds<R: Rounds>(v: &[__m128i; 4]) -> [[__m128i; 4]; PAR_BLOCKS] { | |
res[block][i] = _mm_add_epi32(res[block][i], v[i]); | ||
} | ||
// add the counter since `v` is lacking updated counter values | ||
res[block][3] = _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32)); | ||
if V::COUNTER_SIZE == 1 { | ||
res[block][3] = _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32)); | ||
} else { | ||
res[block][3] = _mm_add_epi64(res[block][3], _mm_set_epi64x(0, block as i64)); | ||
} | ||
} | ||
|
||
res | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -215,6 +215,8 @@ pub struct ChaChaCore<R: Rounds, V: Variant> { | |
/// CPU target feature tokens | ||
#[allow(dead_code)] | ||
tokens: Tokens, | ||
/// Current counter position | ||
counter: u64, | ||
Comment on lines
+218
to
+219
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to add additional state to This is just for testing? |
||
/// Number of rounds to perform | ||
rounds: PhantomData<R>, | ||
/// the variant of the implementation | ||
|
@@ -254,9 +256,11 @@ impl<R: Rounds, V: Variant> ChaChaCore<R, V> { | |
let tokens = (); | ||
} | ||
} | ||
debug_assert_eq!(state[12], 0); | ||
Self { | ||
state, | ||
tokens, | ||
counter: 0, | ||
rounds: PhantomData, | ||
variant: PhantomData, | ||
} | ||
|
@@ -265,24 +269,40 @@ impl<R: Rounds, V: Variant> ChaChaCore<R, V> { | |
|
||
#[cfg(feature = "cipher")] | ||
impl<R: Rounds, V: Variant> StreamCipherSeekCore for ChaChaCore<R, V> { | ||
type Counter = u32; | ||
type Counter = u64; | ||
Comment on lines
-268
to
+272
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should depend on the variant. |
||
|
||
#[inline(always)] | ||
fn get_block_pos(&self) -> Self::Counter { | ||
self.state[12] | ||
let le_pos = self.counter.to_le_bytes(); | ||
let max_val = V::MAX_USABLE_COUNTER; | ||
if self.counter <= max_val { | ||
for i in 0..V::COUNTER_SIZE { | ||
debug_assert_eq!( | ||
self.state[12 + i], | ||
u32::from_le_bytes(<_>::try_from(&le_pos[4 * i..(4 * (i + 1))]).unwrap()) | ||
); | ||
} | ||
} | ||
self.counter | ||
Comment on lines
+276
to
+286
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assuming COUNTER_SIZE <= 2, it's much simpler to do this: let mut counter = self.state[12];
if V::COUNTER_SIZE > 1 {
counter |= self.state[13] << 32;
} |
||
} | ||
|
||
#[inline(always)] | ||
fn set_block_pos(&mut self, pos: Self::Counter) { | ||
self.state[12] = pos | ||
self.counter = pos; | ||
let le_pos = pos.to_le_bytes(); | ||
for i in 0..V::COUNTER_SIZE { | ||
self.state[12 + i] = | ||
u32::from_le_bytes(<_>::try_from(&le_pos[4 * i..(4 * (i + 1))]).unwrap()); | ||
} | ||
} | ||
} | ||
|
||
#[cfg(feature = "cipher")] | ||
impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> { | ||
#[inline(always)] | ||
fn remaining_blocks(&self) -> Option<usize> { | ||
let rem = u32::MAX - self.get_block_pos(); | ||
let max_val = V::MAX_USABLE_COUNTER; | ||
let rem = max_val.saturating_sub(self.get_block_pos()); | ||
rem.try_into().ok() | ||
} | ||
|
||
|
@@ -301,7 +321,7 @@ impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> { | |
} | ||
} else if #[cfg(chacha20_force_sse2)] { | ||
unsafe { | ||
backends::sse2::inner::<R, _>(&mut self.state, f); | ||
backends::sse2::inner::<R, V, _>(&mut self.counter, &mut self.state, f); | ||
} | ||
} else { | ||
let (avx2_token, sse2_token) = self.tokens; | ||
|
@@ -311,7 +331,7 @@ impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> { | |
} | ||
} else if sse2_token.get() { | ||
unsafe { | ||
backends::sse2::inner::<R, _>(&mut self.state, f); | ||
backends::sse2::inner::<R, V, _>(&mut self.counter, &mut self.state, f); | ||
} | ||
} else { | ||
f.call(&mut backends::soft::Backend(self)); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Old behaviour when COUNTER_SIZE=1 was to wrap. New behaviour is to saturate, but I don't think this is justified?
This also saturates when COUNTER_SIZE=2. Arguably it should wrap just in case someone set the block position to something just below the end.
Hint: it may be easier to convert to
u64
, increment and convert back, or what you did a few lines below.