From 076242a4842483399b61c1ff5170fa0e915d2cc0 Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Tue, 24 Dec 2024 18:25:57 +0100 Subject: [PATCH 1/2] store function pointer for the correct `compare256` in a static when e.g. the avx2 target feature is not enabled at compile time, but the feature is available at runtime, this approach reduces branching. We still dispatch statically if the target feature is already enabled at compile time --- zlib-rs/src/deflate/compare256.rs | 68 +++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/zlib-rs/src/deflate/compare256.rs b/zlib-rs/src/deflate/compare256.rs index 4c531caf..ace564c3 100644 --- a/zlib-rs/src/deflate/compare256.rs +++ b/zlib-rs/src/deflate/compare256.rs @@ -1,6 +1,7 @@ #[cfg(test)] const MAX_COMPARE_SIZE: usize = 256; +#[inline(always)] pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize { let src0 = first_chunk::<_, 256>(src0).unwrap(); let src1 = first_chunk::<_, 256>(src1).unwrap(); @@ -8,23 +9,66 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize { compare256(src0, src1) } +/// Call the best available `compare256` implementation +/// +/// We attempt to call a specific version if its target feature is enabled at compile time +/// (e.g. via `-Ctarget-cpu`). If no such target feature is enabled at compile time, we defer to +/// [`compare256_via_function_pointer`]. 
+#[inline(always)] fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { - #[cfg(target_arch = "x86_64")] - if crate::cpu_features::is_enabled_avx2() { - return unsafe { avx2::compare256(src0, src1) }; - } + #[cfg(target_feature = "avx2")] + return avx2::compare256(src0, src1); - #[cfg(target_arch = "aarch64")] - if crate::cpu_features::is_enabled_neon() { - return unsafe { neon::compare256(src0, src1) }; - } + #[cfg(target_feature = "neon")] + return neon::compare256(src0, src1); - #[cfg(target_arch = "wasm32")] - if crate::cpu_features::is_enabled_simd128() { - return wasm32::compare256(src0, src1); + #[cfg(target_feature = "simd128")] + return wasm32::compare256(src0, src1); + + #[allow(unreachable_code)] + compare256_via_function_pointer(src0, src1) +} + +/// Choose the best available implementation at runtime +/// +/// We store the function pointer to the best implementation in an `AtomicPtr`; every call +/// loads this function pointer and then calls it. +/// +/// The value is initially set to `initializer`, which on the first call will determine what the +/// most efficient implementation is, and overwrite the value in the atomic, so that on subsequent +/// calls the best implementation is called immediately. 
+#[inline(always)] +fn compare256_via_function_pointer(src0: &[u8; 256], src1: &[u8; 256]) -> usize { + use core::sync::atomic::{AtomicPtr, Ordering}; + + type F = unsafe fn(&[u8; 256], &[u8; 256]) -> usize; + + static PTR: AtomicPtr<()> = AtomicPtr::new(initializer as *mut ()); + + fn initializer(src0: &[u8; 256], src1: &[u8; 256]) -> usize { + let ptr = match () { + #[cfg(target_arch = "x86_64")] + _ if crate::cpu_features::is_enabled_avx2() => avx2::compare256 as F, + #[cfg(target_arch = "aarch64")] + _ if crate::cpu_features::is_enabled_neon() => neon::compare256 as F, + #[cfg(target_arch = "wasm32")] + _ if crate::cpu_features::is_enabled_simd128() => wasm32::compare256 as F, + _ => rust::compare256 as F, + }; + + PTR.store(ptr as *mut (), Ordering::Relaxed); + + // Safety: we've validated the target feature requirements + unsafe { ptr(src0, src1) } } - rust::compare256(src0, src1) + let ptr = PTR.load(Ordering::Relaxed); + + // Safety: we trust this function pointer (PTR is local to the function) + let dynamic_compare256 = unsafe { core::mem::transmute::<*mut (), F>(ptr) }; + + // Safety: we've validated the target feature requirements + unsafe { dynamic_compare256(src0, src1) } } pub fn compare256_rle_slice(byte: u8, src: &[u8]) -> usize { From 4c71540c40a5d6da273bfd2da53a299903ac4bf5 Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Tue, 24 Dec 2024 18:27:45 +0100 Subject: [PATCH 2/2] simplify quick match matching --- zlib-rs/src/deflate/algorithm/quick.rs | 19 +++++++------------ zlib-rs/src/deflate/compare256.rs | 6 +++--- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/zlib-rs/src/deflate/algorithm/quick.rs b/zlib-rs/src/deflate/algorithm/quick.rs index fd9e21cd..93c1dbc2 100644 --- a/zlib-rs/src/deflate/algorithm/quick.rs +++ b/zlib-rs/src/deflate/algorithm/quick.rs @@ -97,20 +97,15 @@ pub fn deflate_quick(stream: &mut DeflateStream, flush: DeflateFlush) -> BlockSt let dist = state.strstart as isize - hash_head as isize; if dist <= 
state.max_dist() as isize && dist > 0 { - let str_start = &state.window.filled()[state.strstart..]; - let match_start = &state.window.filled()[hash_head as usize..]; + let str_start = &state.window.filled()[state.strstart..][..258]; + let match_start = &state.window.filled()[hash_head as usize..][..258]; - macro_rules! first_two_bytes { - ($slice:expr, $offset:expr) => { - u16::from_le_bytes($slice[$offset..$offset + 2].try_into().unwrap()) - }; - } + let (prefix1, tail1) = str_start.split_at(2); + let (prefix2, tail2) = match_start.split_at(2); - if first_two_bytes!(str_start, 0) == first_two_bytes!(match_start, 0) { - let mut match_len = crate::deflate::compare256::compare256_slice( - &str_start[2..], - &match_start[2..], - ) + 2; + if prefix1 == prefix2 { + let mut match_len = + 2 + crate::deflate::compare256::compare256_slice(tail1, tail2); if match_len >= WANT_MIN_MATCH { match_len = Ord::min(match_len, state.lookahead); diff --git a/zlib-rs/src/deflate/compare256.rs b/zlib-rs/src/deflate/compare256.rs index ace564c3..19e77621 100644 --- a/zlib-rs/src/deflate/compare256.rs +++ b/zlib-rs/src/deflate/compare256.rs @@ -17,13 +17,13 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize { #[inline(always)] fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { #[cfg(target_feature = "avx2")] - return avx2::compare256(src0, src1); + return unsafe { avx2::compare256(src0, src1) }; #[cfg(target_feature = "neon")] - return neon::compare256(src0, src1); + return unsafe { neon::compare256(src0, src1) }; #[cfg(target_feature = "simd128")] - return wasm32::compare256(src0, src1); + return unsafe { wasm32::compare256(src0, src1) }; #[allow(unreachable_code)] compare256_via_function_pointer(src0, src1)