diff --git a/tlvc/src/lib.rs b/tlvc/src/lib.rs index 00b64d5..bc80386 100644 --- a/tlvc/src/lib.rs +++ b/tlvc/src/lib.rs @@ -50,8 +50,21 @@ impl ChunkHeader { /// Compute the length of this chunk in bytes, including the header, /// body, any padding, and the trailing checksum. + /// + /// ## Panics + /// + /// If the chunk's total length cannot be represented as a usize. pub fn total_len_in_bytes(&self) -> usize { - size_of::() + round_up_usize(self.len.get() as usize) + 4 + // Note: it would be nice to avoid the panic here, but it's not + // possible because the len field is public and can be mutated. + // Preferably the field(s) would be behind getters, in which case + // creation of ChunkHeader could check that the length field is + // reasonable. + const HEADER_AND_CHECKSUM_BYTES: usize = + size_of::() + size_of::(); + round_up_usize(self.len.get() as usize) + .and_then(|l| l.checked_add(HEADER_AND_CHECKSUM_BYTES)) + .unwrap() } } @@ -76,7 +89,7 @@ pub trait TlvcRead: Clone { impl TlvcRead for &'_ [u8] { type Error = core::convert::Infallible; fn extent(&self) -> Result> { - Ok(u64::try_from(self.len()).unwrap()) + u64::try_from(self.len()).map_err(|_| TlvcReadError::Truncated) } fn read_exact( @@ -84,8 +97,11 @@ impl TlvcRead for &'_ [u8] { offset: u64, dest: &mut [u8], ) -> Result<(), TlvcReadError> { - let offset = usize::try_from(offset).unwrap(); - let end = offset.checked_add(dest.len()).unwrap(); + let offset = + usize::try_from(offset).map_err(|_| TlvcReadError::Truncated)?; + let end = offset + .checked_add(dest.len()) + .ok_or(TlvcReadError::Truncated)?; dest.copy_from_slice(&self[offset..end]); Ok(()) } @@ -150,7 +166,7 @@ impl TlvcReader { /// Returns the number of bytes remaining in this reader. pub fn remaining(&self) -> u64 { - self.limit - self.position + self.limit.saturating_sub(self.position) } /// Destroys this reader and returns the original `source`, the byte @@ -188,15 +204,29 @@ impl TlvcReader { } let header = self.read_header()?; - let body_position = self.position + size_of::() as u64; - let body_len = round_up(u64::from(header.len.get())); - let chunk_end = body_position + body_len + 4; + // SAFETY: read_header has performed checked_add on the same values and + // returned an Err(TlvcReadError::Truncated) if this did overflow. + let body_position = unsafe { + self.position.unchecked_add(size_of::() as u64) + }; + // Note: this cannot overflow as we go from a u32 to a u64, and the + // compiler sees it too and removes the panic branch here. + let body_and_checksum_size = round_up_u32_to_u64(header.len.get()) + 4; + // Note: ChunkHandle::read_as_chunks assumes that this check is + // performed. Removing this would make the SAFETY comment there invalid + // and risk undefined behaviour. + let chunk_end = body_position + .checked_add(body_and_checksum_size) + .ok_or(TlvcReadError::Truncated)?; if chunk_end > self.limit { return Err(TlvcReadError::Truncated); } self.position = chunk_end; + // Note: ChunkHandle::read_as_chunks assumes that this is the only + // code path creating handles, and that above + // body_position + header.len is checked to not overflow. Ok(Some(ChunkHandle { source: self.source.clone(), header, @@ -245,25 +275,24 @@ impl TlvcReader { pub fn skip_chunk(&mut self) -> Result<(), TlvcReadError> { let h = self.read_header()?; - // Compute the overall size of the header, contents (rounded up for - // alignment), and the trailing checksum (which we're not going to - // check). - let size = size_of::() as u64 - + round_up(u64::from(h.len.get())) - + size_of::() as u64; - + // Compute the overall size of the contents (rounded up for alignment), + // header, and the trailing checksum (which we're not going to check). + // Note: this cannot overflow as we go from a u32 to a u64, and the + // compiler sees it too and removes the panic branch here. + let chunk_size = round_up_u32_to_u64(h.len.get()) + + (size_of::() + size_of::()) as u64; // Bump our new position forward as long as it doesn't cross our limit. // This may leave us zero-length. That's ok. - let p = self + let chunk_end = self .position - .checked_add(size) + .checked_add(chunk_size) .ok_or(TlvcReadError::Truncated)?; - if p > self.limit { + if chunk_end > self.limit { return Err(TlvcReadError::Truncated); } - self.position = p; + self.position = chunk_end; Ok(()) } @@ -321,13 +350,21 @@ impl ChunkHandle { R: TlvcRead, { let end = position - .checked_add(u64::try_from(dest.len()).unwrap()) + .checked_add( + u64::try_from(dest.len()) + .map_err(|_| TlvcReadError::Truncated)?, + ) .ok_or(TlvcReadError::Truncated)?; if end > self.len() { return Err(TlvcReadError::Truncated); } - self.source.read_exact(self.body_position + position, dest) + let offset = self + .body_position + .checked_add(position) + .ok_or(TlvcReadError::Truncated)?; + + self.source.read_exact(offset, dest) } /// Produces a `TlvcReader` that can be used to interpret this chunk's body @@ -347,7 +384,14 @@ impl ChunkHandle { TlvcReader { source: self.source.clone(), position: self.body_position, - limit: self.body_position + u64::from(self.header.len.get()), + // SAFETY: creation of ChunkHandle in TlvcReader::next checks that + // body_position + header.len does not overflow. ChunkHandle has no + // '&mut self' methods and the fields are private, so this addition + // still cannot overflow. + limit: unsafe { + self.body_position + .unchecked_add(u64::from(self.header.len.get())) + }, } } @@ -363,22 +407,34 @@ impl ChunkHandle { where R: TlvcRead, { + // Caclulate the body checksum. let mut c = begin_body_crc(); - let end = self.body_position + self.header.len.get() as u64; - let mut pos = self.body_position; - while pos != end { - let portion = usize::try_from(end - pos) - .unwrap_or(usize::MAX) - .min(buffer.len()); - self.source.read_exact(pos, &mut buffer[..portion])?; - c.update(&buffer[..portion]); - pos += u64::try_from(portion).unwrap(); + let mut pos = usize::try_from(self.body_position) + .map_err(|_| TlvcReadError::Truncated)?; + let contents_len = usize::try_from(self.header.len.get()) + .map_err(|_| TlvcReadError::Truncated)?; + let end = pos + .checked_add(contents_len) + .ok_or(TlvcReadError::Truncated)?; + let len = buffer.len(); + while pos < end { + let portion = (end - pos).min(len); + let buf = &mut buffer[..portion]; + self.source.read_exact( + u64::try_from(pos).map_err(|_| TlvcReadError::Truncated)?, + buf, + )?; + c.update(buf); + pos += portion; } - let computed_checksum = c.finalize(); + + // Read the stored checksum at the end of the chunk and compare. let mut stored_checksum = 0u32; - self.source - .read_exact(round_up(end), stored_checksum.as_bytes_mut())?; + self.source.read_exact( + round_up_usize_to_u64(end).ok_or(TlvcReadError::Truncated)?, + stored_checksum.as_bytes_mut(), + )?; if computed_checksum != stored_checksum { Err(TlvcReadError::BodyCorrupt { @@ -407,12 +463,19 @@ pub fn compute_body_crc(data: &[u8]) -> u32 { c.finalize() } -fn round_up(x: u64) -> u64 { - (x + 0b11) & !0b11 +#[inline(always)] +fn round_up_u32_to_u64(x: u32) -> u64 { + (x as u64 + 0b11) & !0b11 +} + +#[inline(always)] +fn round_up_usize(x: usize) -> Option { + Some(x.checked_add(0b11)? & !0b11) } -fn round_up_usize(x: usize) -> usize { - (x + 0b11) & !0b11 +#[inline(always)] +fn round_up_usize_to_u64(x: usize) -> Option { + Some((u64::try_from(x).ok()?).checked_add(0b11)? & !0b11) } #[cfg(test)]