Skip to content
232 changes: 149 additions & 83 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -638,62 +638,118 @@ fn read_chunk(bytes: &[u8]) -> Result<Chunk, DecodeError> {
Ok(Chunk { chunk_type, data: &data_for_crc[4..], crc })
}

fn defilter(
// BPP = Bytes Per Pixel
fn defilter<const BPP: usize>(
filter_type: FilterType,
bytes_per_pixel: usize,
x: usize,
current_scanline: &[u8],
current_scanline: &mut [u8],
last_scanline: &[u8],
) -> u8 {
) {
match filter_type {
FilterType::None => current_scanline[x],
FilterType::None => {},
FilterType::Sub => {
if let Some(idx) = x.checked_sub(bytes_per_pixel) {
current_scanline[x].wrapping_add(current_scanline[idx])
} else {
current_scanline[x]
let mut chunk_iter = current_scanline.chunks_exact_mut(BPP);
let mut prev_chunk = chunk_iter.next().unwrap();

for current_chunk in &mut chunk_iter {
for (current_byte, prev_byte) in current_chunk.iter_mut().zip(prev_chunk) {
*current_byte = current_byte.wrapping_add(*prev_byte);
}

prev_chunk = current_chunk;
}
},
FilterType::Up => {
for (current, above) in (current_scanline.iter_mut()).zip(last_scanline) {
*current = current.wrapping_add(*above);
}
},
FilterType::Up => current_scanline[x].wrapping_add(last_scanline[x]),
FilterType::Average => {
let raw_val = if let Some(idx) = x.checked_sub(bytes_per_pixel) {
current_scanline[idx]
} else {
0
};
for x in 0..(BPP) {
current_scanline[x] = current_scanline[x].wrapping_add((last_scanline[x]) / 2);
}

let mut chunk_iter = current_scanline.chunks_exact_mut(BPP);
let mut left_chunk = chunk_iter.next().unwrap();

let upper_iter = last_scanline[BPP..].chunks_exact(BPP);

(current_scanline[x] as u16 + ((raw_val as u16 + last_scanline[x] as u16) / 2)) as u8
for (current_chunk, upper_chunk) in (&mut chunk_iter).zip(upper_iter) {
for i in 0..BPP {
let left_byte = left_chunk[i];
let upper_byte = upper_chunk[i];

current_chunk[i] = current_chunk[i]
.wrapping_add(((left_byte as u16 + upper_byte as u16) / 2) as u8);
}

left_chunk = current_chunk;
}
},
FilterType::Paeth => {
if let Some(idx) = x.checked_sub(bytes_per_pixel) {
let left = current_scanline[idx];
let above = last_scanline[x];
let upper_left = last_scanline[idx];
for x in 0..(BPP) {
let predictor = paeth_predictor(0, last_scanline[x] as i16, 0);
current_scanline[x] = current_scanline[x].wrapping_add(predictor);
}

let mut chunk_iter = current_scanline.chunks_exact_mut(BPP);
let mut left_chunk = chunk_iter.next().unwrap();

let upper_left_iter = last_scanline.chunks_exact(BPP);
let upper_iter = last_scanline[BPP..].chunks_exact(BPP);

let predictor = paeth_predictor(left as i16, above as i16, upper_left as i16);
for ((current_chunk, upper_left_chunk), upper_chunk) in
(&mut chunk_iter).zip(upper_left_iter).zip(upper_iter)
{
for i in 0..BPP {
let left_byte = left_chunk[i];
let upper_left_byte = upper_left_chunk[i];
let upper_byte = upper_chunk[i];

current_scanline[x].wrapping_add(predictor)
} else {
let left = 0;
let above = last_scanline[x];
let upper_left = 0;
let predictor = paeth_predictor(
left_byte as i16,
upper_byte as i16,
upper_left_byte as i16,
);

let predictor = paeth_predictor(left as i16, above as i16, upper_left as i16);
current_chunk[i] = current_chunk[i].wrapping_add(predictor);
}

current_scanline[x].wrapping_add(predictor)
left_chunk = current_chunk;
}
},
}
}

/// Paeth predictor used when reversing the `Paeth` scanline filter.
///
/// * `a` - byte to the left
/// * `b` - byte above
/// * `c` - byte to the upper left
///
/// Returns whichever of the three inputs lies closest to the linear estimate
/// `a + b - c`, preferring `a`, then `b`, then `c` when distances tie. The
/// winner is chosen with all-ones / all-zeros byte masks instead of branches.
#[inline(always)]
fn paeth_predictor(a: i16, b: i16, c: i16) -> u8 {
    // TODO(bschwind) - Accept i16 or convert once and store in a temp.
    let estimate = a + b - c;
    let dist_left = (estimate - a).abs();
    let dist_above = (estimate - b).abs();
    let dist_upper_left = (estimate - c).abs();

    // Exactly one of these three flags ends up true.
    let pick_left = dist_left <= dist_above && dist_left <= dist_upper_left;
    let pick_above = !pick_left && dist_above <= dist_upper_left;
    let pick_upper_left = !(pick_left || pick_above);

    // Negating 1 wraps to 0xFF and negating 0 stays 0x00, so each flag
    // becomes a mask that either passes its candidate through or zeroes it.
    let left_mask = (pick_left as u8).wrapping_neg();
    let above_mask = (pick_above as u8).wrapping_neg();
    let upper_left_mask = (pick_upper_left as u8).wrapping_neg();

    (left_mask & (a as u8)) | (above_mask & (b as u8)) | (upper_left_mask & (c as u8))
}

fn process_scanlines(
header: &PngHeader,
scanline_data: &mut [u8],
mut scanline_data: &mut [u8],
output_rgba: &mut [u8],
ancillary_chunks: &AncillaryChunks,
pixel_type: PixelType,
) -> Result<(), DecodeError> {
let mut cursor = 0;
let bytes_per_pixel: usize =
((header.bit_depth as usize * header.color_type.sample_multiplier()) + 7) / 8;

Expand All @@ -708,21 +764,34 @@ fn process_scanlines(
let bytes_per_scanline: usize =
bytes_per_scanline.try_into().map_err(|_| DecodeError::IntegerOverflow)?;

let mut last_scanline = vec![0u8; bytes_per_scanline];
let zero_scanline = vec![0u8; bytes_per_scanline];
let mut last_scanline: &[u8] = &zero_scanline;

let mut total_defilter = std::time::Duration::from_secs(0);
let mut total_scanline = std::time::Duration::from_secs(0);

for y in 0..header.height {
let filter_type = FilterType::try_from(scanline_data[cursor])
let filter_type = FilterType::try_from(scanline_data[0])
.map_err(|_| DecodeError::InvalidFilterType)?;
cursor += 1;

let current_scanline = &mut scanline_data[cursor..(cursor + bytes_per_scanline)];
let (current_scanline, scanline_data_tail) =
scanline_data[1..].split_at_mut(bytes_per_scanline);

for x in 0..(bytes_per_scanline) {
let unfiltered_byte =
defilter(filter_type, bytes_per_pixel, x, current_scanline, &last_scanline);
current_scanline[x] = unfiltered_byte;
let now = std::time::Instant::now();

match bytes_per_pixel {
1 => defilter::<1>(filter_type, current_scanline, &last_scanline),
2 => defilter::<2>(filter_type, current_scanline, &last_scanline),
3 => defilter::<3>(filter_type, current_scanline, &last_scanline),
4 => defilter::<4>(filter_type, current_scanline, &last_scanline),
6 => defilter::<6>(filter_type, current_scanline, &last_scanline),
8 => defilter::<8>(filter_type, current_scanline, &last_scanline),
_ => {},
}

total_defilter += now.elapsed();

let now = std::time::Instant::now();
let scanline_iter = ScanlineIterator::new(
header.width,
pixel_type,
Expand All @@ -744,13 +813,19 @@ fn process_scanlines(
output_rgba[output_idx + 3] = a;
}

last_scanline.copy_from_slice(current_scanline);
cursor += bytes_per_scanline;
total_scanline += now.elapsed();

last_scanline = current_scanline;
scanline_data = scanline_data_tail;
}

println!("total_defilter took {:?}", total_defilter);
println!("total_scanline took {:?}", total_scanline);
},
InterlaceMethod::Adam7 => {
let max_bytes_per_scanline = header.width as usize * bytes_per_pixel;
let mut last_scanline = vec![0u8; max_bytes_per_scanline];

let zero_scanline = vec![0u8; max_bytes_per_scanline];

// Adam7 Interlacing Pattern
// 1 6 4 6 2 6 4 6
Expand Down Expand Up @@ -815,28 +890,23 @@ fn process_scanlines(
let bytes_per_scanline: usize =
bytes_per_scanline.try_into().expect("bytes_per_scanline overflowed a usize");

let last_scanline = &mut last_scanline[..(bytes_per_scanline)];
for byte in last_scanline.iter_mut() {
*byte = 0;
}
let mut last_scanline = &zero_scanline[..(bytes_per_scanline)];

for y in 0..pass_height {
let filter_type = FilterType::try_from(scanline_data[cursor])
let filter_type = FilterType::try_from(scanline_data[0])
.map_err(|_| DecodeError::InvalidFilterType)?;
cursor += 1;

let current_scanline =
&mut scanline_data[cursor..(cursor + bytes_per_scanline)];

for x in 0..(bytes_per_scanline) {
let unfiltered_byte = defilter(
filter_type,
bytes_per_pixel,
x,
current_scanline,
&last_scanline,
);
current_scanline[x] = unfiltered_byte;

let (current_scanline, scanline_data_tail) =
scanline_data[1..].split_at_mut(bytes_per_scanline);

match bytes_per_pixel {
1 => defilter::<1>(filter_type, current_scanline, &last_scanline),
2 => defilter::<2>(filter_type, current_scanline, &last_scanline),
3 => defilter::<3>(filter_type, current_scanline, &last_scanline),
4 => defilter::<4>(filter_type, current_scanline, &last_scanline),
6 => defilter::<6>(filter_type, current_scanline, &last_scanline),
8 => defilter::<8>(filter_type, current_scanline, &last_scanline),
_ => {},
}

let scanline_iter = ScanlineIterator::new(
Expand Down Expand Up @@ -870,9 +940,8 @@ fn process_scanlines(
output_rgba[output_idx + 3] = a;
}

last_scanline.copy_from_slice(current_scanline);

cursor += bytes_per_scanline;
last_scanline = current_scanline;
scanline_data = scanline_data_tail;
}
}
},
Expand All @@ -881,25 +950,6 @@ fn process_scanlines(
Ok(())
}

/// Paeth predictor used when reversing the `Paeth` scanline filter.
///
/// * `a` - byte to the left
/// * `b` - byte above
/// * `c` - byte to the upper left
///
/// Returns whichever input is nearest to `a + b - c`; ties are resolved in
/// the order `a`, `b`, `c`.
fn paeth_predictor(a: i16, b: i16, c: i16) -> u8 {
    // TODO(bschwind) - Accept i16 or convert once and store in a temp.
    let estimate = a + b - c;
    let (dist_a, dist_b, dist_c) =
        ((estimate - a).abs(), (estimate - b).abs(), (estimate - c).abs());

    let a_is_nearest = dist_a <= dist_b && dist_a <= dist_c;
    let b_beats_c = dist_b <= dist_c;

    let nearest = match (a_is_nearest, b_beats_c) {
        (true, _) => a,
        (false, true) => b,
        (false, false) => c,
    };

    nearest as u8
}

pub fn decode(bytes: &[u8]) -> Result<(PngHeader, Vec<u8>), DecodeError> {
if bytes.len() < PNG_MAGIC_BYTES.len() {
return Err(DecodeError::MissingBytes);
Expand All @@ -922,6 +972,7 @@ pub fn decode(bytes: &[u8]) -> Result<(PngHeader, Vec<u8>), DecodeError> {
let pixel_type = PixelType::new(header.color_type, header.bit_depth)?;
let mut ancillary_chunks = AncillaryChunks::default();

let now = std::time::Instant::now();
while !bytes.is_empty() {
let chunk = read_chunk(bytes)?;

Expand All @@ -938,19 +989,26 @@ pub fn decode(bytes: &[u8]) -> Result<(PngHeader, Vec<u8>), DecodeError> {
bytes = &bytes[chunk.byte_size()..];
}

println!("Chunk reading took {:?}", now.elapsed());

let now = std::time::Instant::now();
let mut scanline_data = miniz_oxide::inflate::decompress_to_vec_zlib(&compressed_data)
.map_err(DecodeError::Decompress)?;

println!("Decompress took {:?}", now.elapsed());

// For now, output data is always RGBA, 1 byte per channel.
let mut output_rgba = vec![0u8; header.width as usize * header.height as usize * 4];

let now = std::time::Instant::now();
process_scanlines(
&header,
&mut scanline_data,
&mut output_rgba,
&ancillary_chunks,
pixel_type,
)?;
println!("process_scanlines took {:?}", now.elapsed());

Ok((header, output_rgba))
}
Expand Down Expand Up @@ -1013,4 +1071,12 @@ mod tests {
}
}
}

/// Decodes the `skyline.png` fixture end to end and prints the wall-clock
/// time taken; the test fails only if `decode` returns an error.
#[test]
fn hd_decode_test() {
    let skyline_png = include_bytes!("../test_pngs/skyline.png");

    let started = std::time::Instant::now();
    // The decoded header and pixels are not inspected here.
    let (_header, _decoded) = decode(skyline_png).unwrap();
    let elapsed = started.elapsed();

    println!("Took {:?}", elapsed);
}
}