diff --git a/src/fixed_vector.rs b/src/fixed_vector.rs index 3fe9e71..afa598c 100644 --- a/src/fixed_vector.rs +++ b/src/fixed_vector.rs @@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root; use crate::Error; use serde::Deserialize; use serde_derive::Serialize; +use std::any::TypeId; use std::marker::PhantomData; +use std::mem; use std::ops::{Deref, DerefMut, Index, IndexMut}; use std::slice::SliceIndex; use tree_hash::Hash256; @@ -283,7 +285,7 @@ impl ssz::TryFromIter for FixedVector { impl ssz::Decode for FixedVector where - T: ssz::Decode, + T: ssz::Decode + 'static, { fn is_ssz_fixed_len() -> bool { T::is_ssz_fixed_len() @@ -305,6 +307,24 @@ where len: 0, expected: 1, }) + } else if TypeId::of::() == TypeId::of::() { + if bytes.len() != fixed_len { + return Err(ssz::DecodeError::BytesInvalid(format!( + "FixedVector of {} items has {} items", + fixed_len, + bytes.len(), + ))); + } + + // Safety: We've verified T is u8, so Vec *is* Vec. + let vec_u8 = bytes.to_vec(); + let vec_t = unsafe { mem::transmute::, Vec>(vec_u8) }; + Self::new(vec_t).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!( + "Wrong number of FixedVector elements: {:?}", + e + )) + }) } else if T::is_ssz_fixed_len() { let num_items = bytes .len() @@ -314,17 +334,24 @@ where if num_items != fixed_len { return Err(ssz::DecodeError::BytesInvalid(format!( "FixedVector of {} items has {} items", - num_items, fixed_len + fixed_len, num_items ))); } - let vec = bytes.chunks(T::ssz_fixed_len()).try_fold( - Vec::with_capacity(num_items), - |mut vec, chunk| { - vec.push(T::from_ssz_bytes(chunk)?); - Ok(vec) - }, - )?; + // Check that we have a whole number of items and that it is safe to use chunks_exact + if !bytes.len().is_multiple_of(T::ssz_fixed_len()) { + return Err(ssz::DecodeError::BytesInvalid(format!( + "FixedVector of {} items has {} bytes", + num_items, + bytes.len() + ))); + } + + let mut vec = Vec::with_capacity(num_items); + for chunk in bytes.chunks_exact(T::ssz_fixed_len()) { + vec.push(T::from_ssz_bytes(chunk)?); + } + Self::new(vec).map_err(|e| { ssz::DecodeError::BytesInvalid(format!( "Wrong number of FixedVector elements: {:?}", @@ -479,6 +506,56 @@ mod test { ssz_round_trip::>(vec![0; 8].try_into().unwrap()); } + // Test byte decoding (we have a specialised code path with unsafe code that NEEDS coverage). + #[test] + fn ssz_round_trip_u8_len_1024() { + ssz_round_trip::>(vec![42; 1024].try_into().unwrap()); + ssz_round_trip::>(vec![0; 1024].try_into().unwrap()); + } + + // bool is layout equivalent to u8 but must not use the same unsafe codepath because not all u8 + // values are valid bools. + #[test] + fn ssz_round_trip_bool_len_1024() { + assert_eq!(mem::size_of::(), 1); + assert_eq!(mem::align_of::(), 1); + ssz_round_trip::>(vec![true; 1024].try_into().unwrap()); + ssz_round_trip::>(vec![false; 1024].try_into().unwrap()); + } + + // Decoding a u8 vector as a vector of bools must fail, if we aren't careful we could trigger UB. + #[test] + fn ssz_u8_to_bool_len_1024() { + let list_u8 = FixedVector::::new(vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap(); + FixedVector::::from_ssz_bytes(&list_u8.as_ssz_bytes()).unwrap_err(); + } + + #[test] + fn ssz_u8_len_1024_too_long() { + assert_eq!( + FixedVector::::from_ssz_bytes(&vec![42; 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into()) + ); + } + + #[test] + fn ssz_u64_len_1024_too_long() { + assert_eq!( + FixedVector::::from_ssz_bytes(&vec![42; 8 * 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into()) + ); + } + + // Decoding an input with invalid trailing bytes MUST fail. + #[test] + fn ssz_bytes_u64_trailing() { + let bytes = [1, 0, 0, 0, 2, 0, 0, 0, 1]; + assert_eq!( + FixedVector::::from_ssz_bytes(&bytes).unwrap_err(), + ssz::DecodeError::BytesInvalid("FixedVector of 2 items has 9 bytes".into()) + ); + } + #[test] fn tree_hash_u8() { let fixed: FixedVector = FixedVector::try_from(vec![]).unwrap(); diff --git a/src/variable_list.rs b/src/variable_list.rs index f98ba8f..a3998fc 100644 --- a/src/variable_list.rs +++ b/src/variable_list.rs @@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root; use crate::Error; use serde::Deserialize; use serde_derive::Serialize; +use std::any::TypeId; use std::marker::PhantomData; +use std::mem; use std::ops::{Deref, DerefMut, Index, IndexMut}; use std::slice::SliceIndex; use tree_hash::Hash256; @@ -288,7 +290,7 @@ impl ssz::TryFromIter for VariableList { impl ssz::Decode for VariableList where - T: ssz::Decode, + T: ssz::Decode + 'static, N: Unsigned, { fn is_ssz_fixed_len() -> bool { @@ -302,6 +304,26 @@ where return Ok(Self::default()); } + if TypeId::of::() == TypeId::of::() { + if bytes.len() > max_len { + return Err(ssz::DecodeError::BytesInvalid(format!( + "VariableList of {} items exceeds maximum of {}", + bytes.len(), + max_len + ))); + } + + // Safety: We've verified T is u8, so Vec *is* Vec. + let vec_u8 = bytes.to_vec(); + let vec_t = unsafe { mem::transmute::, Vec>(vec_u8) }; + return Self::new(vec_t).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!( + "Wrong number of VariableList elements: {:?}", + e + )) + }); + } + if T::is_ssz_fixed_len() { let num_items = bytes .len() @@ -315,20 +337,28 @@ where ))); } - bytes.chunks(T::ssz_fixed_len()).try_fold( - Vec::with_capacity(num_items), - |mut vec, chunk| { - vec.push(T::from_ssz_bytes(chunk)?); - Ok(vec) - }, - ) + // Check that we have a whole number of items and that it is safe to use chunks_exact + if !bytes.len().is_multiple_of(T::ssz_fixed_len()) { + return Err(ssz::DecodeError::BytesInvalid(format!( + "VariableList of {} items has {} bytes", + num_items, + bytes.len() + ))); + } + + let mut vec = Vec::with_capacity(num_items); + for chunk in bytes.chunks_exact(T::ssz_fixed_len()) { + vec.push(T::from_ssz_bytes(chunk)?); + } + Self::new(vec).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!( + "Wrong number of VariableList elements: {:?}", + e + )) + }) } else { ssz::decode_list_of_variable_length_items(bytes, Some(max_len)) - }? - .try_into() - .map_err(|e| { - ssz::DecodeError::BytesInvalid(format!("VariableList::try_from failed: {e:?}")) - }) + } } } @@ -452,7 +482,7 @@ mod test { assert_eq!( as Encode>::ssz_fixed_len(), 4); } - fn round_trip(item: T) { + fn ssz_round_trip(item: T) { let encoded = &item.as_ssz_bytes(); assert_eq!(item.ssz_bytes_len(), encoded.len()); assert_eq!(T::from_ssz_bytes(encoded), Ok(item)); @@ -460,9 +490,52 @@ mod test { #[test] fn u16_len_8() { - round_trip::>(vec![42; 8].try_into().unwrap()); - round_trip::>(vec![0; 8].try_into().unwrap()); - round_trip::>(vec![].try_into().unwrap()); + ssz_round_trip::>(vec![42; 8].try_into().unwrap()); + ssz_round_trip::>(vec![0; 8].try_into().unwrap()); + ssz_round_trip::>(vec![].try_into().unwrap()); + } + + #[test] + fn ssz_round_trip_u8_len_1024() { + ssz_round_trip::>(vec![42; 1024].try_into().unwrap()); + ssz_round_trip::>(vec![0; 1024].try_into().unwrap()); + } + + // bool is layout equivalent to u8 but must not use the same unsafe codepath because not all u8 + // values are valid bools. + #[test] + fn ssz_round_trip_bool_len_1024() { + assert_eq!(mem::size_of::(), 1); + assert_eq!(mem::align_of::(), 1); + ssz_round_trip::>(vec![true; 1024].try_into().unwrap()); + ssz_round_trip::>(vec![false; 1024].try_into().unwrap()); + } + + // Decoding a u8 list as a list of bools must fail, if we aren't careful we could trigger UB. + #[test] + fn ssz_u8_to_bool_len_1024() { + let list_u8 = VariableList::::new(vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap(); + VariableList::::from_ssz_bytes(&list_u8.as_ssz_bytes()).unwrap_err(); + } + + #[test] + fn ssz_u8_len_1024_too_long() { + assert_eq!( + VariableList::::from_ssz_bytes(&vec![42; 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid( + "VariableList of 1025 items exceeds maximum of 1024".into() + ) + ); + } + + #[test] + fn ssz_u64_len_1024_too_long() { + assert_eq!( + VariableList::::from_ssz_bytes(&vec![42; 8 * 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid( + "VariableList of 1025 items exceeds maximum of 1024".into() + ) + ); } #[test] @@ -473,6 +546,21 @@ mod test { assert_eq!(VariableList::from_ssz_bytes(&[]).unwrap(), empty_list); } + #[test] + fn ssz_bytes_u32_trailing() { + let bytes = [1, 0, 0, 0, 2, 0]; + assert_eq!( + VariableList::::from_ssz_bytes(&bytes).unwrap_err(), + ssz::DecodeError::BytesInvalid("VariableList of 1 items has 6 bytes".into()) + ); + + let bytes = [1, 0, 0, 0, 2, 0, 0, 0, 3]; + assert_eq!( + VariableList::::from_ssz_bytes(&bytes).unwrap_err(), + ssz::DecodeError::BytesInvalid("VariableList of 2 items has 9 bytes".into()) + ); + } + fn root_with_length(bytes: &[u8], len: usize) -> Hash256 { let root = merkle_root(bytes, 0); tree_hash::mix_in_length(&root, len)