
Commit d645ed7

add vpclmulqdq implementation

1 parent: 865137a

File tree: 5 files changed, +160 −17 lines

.github/workflows/checks.yaml

Lines changed: 4 additions & 0 deletions
@@ -480,6 +480,10 @@ jobs:
         run: "cargo +nightly miri nextest run -j4 -p test-libz-rs-sys --target ${{ matrix.target }} null::"
         env:
           RUSTFLAGS: "-Ctarget-feature=+avx2,+bmi2,+bmi1"
+      - name: Test avx512 crc32 implementation
+        run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=vpclmulqdq crc32::"
+        env:
+          RUSTFLAGS: "-Ctarget-feature=+vpclmulqdq,+avx512f"
       - name: Test allocator with miri
         run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} allocate::"
       - name: Test gz logic with miri
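
The added step builds with RUSTFLAGS="-Ctarget-feature=+vpclmulqdq,+avx512f", which lets the compiler assume those instructions unconditionally; that is fine under Miri, but a binary built this way must only run on a CPU that actually supports them. As a hedged sketch (not part of this commit), the runtime check a caller would rely on when the features are not promised at compile time looks like this:

    // Sketch only: runtime CPU-feature detection, the counterpart to the
    // compile-time `-Ctarget-feature` flags used by the CI step above.
    #[cfg(target_arch = "x86_64")]
    fn avx512_crc32_usable() -> bool {
        std::arch::is_x86_feature_detected!("vpclmulqdq")
            && std::arch::is_x86_feature_detected!("avx512f")
    }

    #[cfg(not(target_arch = "x86_64"))]
    fn avx512_crc32_usable() -> bool {
        false
    }

    fn main() {
        println!("AVX-512 crc32 path usable on this CPU: {}", avx512_crc32_usable());
    }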

zlib-rs/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ __internal-fuzz = ["arbitrary"]
 __internal-fuzz-disable-checksum = [] # disable checksum validation on inflate
 __internal-test = ["quickcheck"]
 ZLIB_DEBUG = []
+vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards


 [dependencies]

zlib-rs/src/crc32.rs

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ mod braid;
 mod combine;
 #[cfg(target_arch = "x86_64")]
 mod pclmulqdq;
+#[cfg(target_arch = "x86_64")]
+#[cfg(feature = "vpclmulqdq")]
+mod vpclmulqdq;

 pub use combine::crc32_combine;

zlib-rs/src/crc32/pclmulqdq.rs

Lines changed: 14 additions & 17 deletions
@@ -1,6 +1,5 @@
-use core::arch::x86_64::__m128i;
 use core::arch::x86_64::{
-    _mm_and_si128, _mm_clmulepi64_si128, _mm_cvtsi32_si128, _mm_extract_epi32, _mm_load_si128,
+    __m128i, _mm_and_si128, _mm_clmulepi64_si128, _mm_extract_epi32, _mm_load_si128,
     _mm_loadu_si128, _mm_or_si128, _mm_shuffle_epi8, _mm_slli_si128, _mm_srli_si128,
     _mm_storeu_si128, _mm_xor_si128,
 };

@@ -24,7 +23,7 @@ const fn reg(input: [u32; 4]) -> __m128i {
 #[derive(Debug, Clone, Copy)]
 #[cfg(target_arch = "x86_64")]
 pub(crate) struct Accumulator {
-    fold: [__m128i; 4],
+    pub(super) fold: [__m128i; 4],
 }

 #[cfg(target_arch = "x86_64")]

@@ -249,7 +248,6 @@ impl Accumulator {
         // bytes of input is needed for the aligning load that occurs. If there's an initial CRC, to
         // carry it forward through the folded CRC there must be 16 - src % 16 + 16 bytes available, which
         // by definition can be up to 15 bytes + one full vector load. */
-        let xmm_initial = _mm_cvtsi32_si128(init_crc as i32);
         let first = init_crc != CRC32_INITIAL_VALUE;
         assert!(src.len() >= 31 || !first);

@@ -281,6 +279,7 @@ impl Accumulator {
         let is_initial = init_crc == CRC32_INITIAL_VALUE;

         if !is_initial {
+            let xmm_initial = reg([init_crc, 0, 0, 0]);
             xmm_crc_part = unsafe { _mm_xor_si128(xmm_crc_part, xmm_initial) };
             init_crc = CRC32_INITIAL_VALUE;
         }

@@ -302,19 +301,17 @@ impl Accumulator {
             src = &src[before.len()..];
         }

-        // if is_x86_feature_detected!("vpclmulqdq") {
-        //     if src.len() >= 256 {
-        //         if COPY {
-        //             // size_t n = fold_16_vpclmulqdq_copy(&xmm_crc0, &xmm_crc1, &xmm_crc2, &xmm_crc3, dst, src, len);
-        //             // dst += n;
-        //         } else {
-        //             // size_t n = fold_16_vpclmulqdq(&xmm_crc0, &xmm_crc1, &xmm_crc2, &xmm_crc3, src, len, xmm_initial, first);
-        //             // first = false;
-        //         }
-        //         // len -= n;
-        //         // src += n;
-        //     }
-        // }
+        #[cfg(feature = "vpclmulqdq")]
+        #[cfg(all(target_feature = "vpclmulqdq", target_feature = "avx512f"))]
+        if src.len() >= 256 {
+            let n;
+            if COPY {
+                n = unsafe { self.fold_16_vpclmulqdq_copy(dst, &mut src) };
+                dst = &mut dst[n..];
+            } else {
+                unsafe { self.fold_16_vpclmulqdq(dst, &mut src, &mut init_crc) };
+            }
+        }

         while src.len() >= 64 {
             let n = unsafe { self.progress::<4, COPY>(dst, &mut src, &mut init_crc) };
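
With this change, when at least 256 bytes remain (and the crate feature plus target features are enabled at compile time), all full 256-byte blocks are folded by the new zmm routine before the existing 64-byte xmm loop and the partial-block handling take over. A small illustrative sketch of that split, using the block sizes visible in the loops of this diff; the helper name is hypothetical and not part of the commit:

    // Hypothetical helper, for illustration only: how many bytes each stage of
    // the dispatch consumes for a given input length.
    fn block_split(len: usize) -> (usize, usize, usize) {
        // fold_16_vpclmulqdq consumes 256 bytes per iteration, but only kicks
        // in when at least 256 bytes are available.
        let zmm = if len >= 256 { (len / 256) * 256 } else { 0 };
        let rest = len - zmm;
        // the existing fold-by-4 loop consumes 64 bytes per iteration.
        let xmm = (rest / 64) * 64;
        // whatever is left goes through the partial-block path.
        (zmm, xmm, rest - xmm)
    }

    fn main() {
        assert_eq!(block_split(1000), (768, 192, 40));
        assert_eq!(block_split(200), (0, 192, 8));
        println!("{:?}", block_split(1000));
    }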

zlib-rs/src/crc32/vpclmulqdq.rs

Lines changed: 138 additions & 0 deletions
use crate::CRC32_INITIAL_VALUE;
use core::arch::x86_64::{
    __m128i, __m512i, _mm512_clmulepi64_epi128, _mm512_extracti32x4_epi32, _mm512_inserti32x4,
    _mm512_loadu_si512, _mm512_set4_epi32, _mm512_setzero_si512, _mm512_storeu_si512,
    _mm512_ternarylogic_epi32, _mm512_xor_si512, _mm512_zextsi128_si512, _mm_cvtsi32_si128,
};

impl super::pclmulqdq::Accumulator {
    #[target_feature(enable = "vpclmulqdq", enable = "avx512f")]
    pub(super) unsafe fn fold_16_vpclmulqdq(
        &mut self,
        dst: &mut [u8],
        src: &mut &[u8],
        init_crc: &mut u32,
    ) -> usize {
        unsafe { self.fold_help_vpclmulqdq::<false>(dst, src, init_crc) }
    }

    #[target_feature(enable = "vpclmulqdq", enable = "avx512f")]
    pub(super) unsafe fn fold_16_vpclmulqdq_copy(
        &mut self,
        dst: &mut [u8],
        src: &mut &[u8],
    ) -> usize {
        unsafe { self.fold_help_vpclmulqdq::<true>(dst, src, &mut CRC32_INITIAL_VALUE) }
    }

    #[target_feature(enable = "vpclmulqdq", enable = "avx512f")]
    unsafe fn fold_help_vpclmulqdq<const COPY: bool>(
        &mut self,
        mut dst: &mut [u8],
        src: &mut &[u8],
        init_crc: &mut u32,
    ) -> usize {
        let [xmm_crc0, xmm_crc1, xmm_crc2, xmm_crc3] = &mut self.fold;
        let start_len = src.len();

        unsafe {
            // CRC-32 folding constants for folding across 4 and 16 128-bit lanes.
            let zmm_fold4 =
                _mm512_set4_epi32(0x00000001, 0x54442bd4, 0x00000001, 0xc6e41596u32 as i32);
            let zmm_fold16 = _mm512_set4_epi32(0x00000001, 0x1542778a, 0x00000001, 0x322d1430);

            // zmm register init
            let zmm_crc0 = _mm512_setzero_si512();
            let mut zmm_t0 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>());

            if !COPY && *init_crc != CRC32_INITIAL_VALUE {
                let xmm_initial = _mm_cvtsi32_si128(*init_crc as i32);
                let zmm_initial = _mm512_zextsi128_si512(xmm_initial);
                zmm_t0 = _mm512_xor_si512(zmm_t0, zmm_initial);
                *init_crc = CRC32_INITIAL_VALUE;
            }

            let mut zmm_crc1 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>().add(1));
            let mut zmm_crc2 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>().add(2));
            let mut zmm_crc3 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>().add(3));

            /* already have intermediate CRC in xmm registers
             * fold4 with 4 xmm_crc to get zmm_crc0
             */
            let mut zmm_crc0 = _mm512_inserti32x4(zmm_crc0, *xmm_crc0, 0);
            zmm_crc0 = _mm512_inserti32x4(zmm_crc0, *xmm_crc1, 1);
            zmm_crc0 = _mm512_inserti32x4(zmm_crc0, *xmm_crc2, 2);
            zmm_crc0 = _mm512_inserti32x4(zmm_crc0, *xmm_crc3, 3);
            let mut z0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x01);
            zmm_crc0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x10);
            // imm8 0x96 is the three-way XOR truth table: zmm_crc0 ^ z0 ^ zmm_t0.
            zmm_crc0 = _mm512_ternarylogic_epi32(zmm_crc0, z0, zmm_t0, 0x96);

            if COPY {
                _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>(), zmm_t0);
                _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>().add(1), zmm_crc1);
                _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>().add(2), zmm_crc2);
                _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>().add(3), zmm_crc3);
                dst = &mut dst[256..];
            }

            *src = &src[256..];

            // fold-16 loops
            while src.len() >= 256 {
                let zmm_t0 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>());
                let zmm_t1 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>().add(1));
                let zmm_t2 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>().add(2));
                let zmm_t3 = _mm512_loadu_si512(src.as_ptr().cast::<__m512i>().add(3));

                let z0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold16, 0x01);
                let z1 = _mm512_clmulepi64_epi128(zmm_crc1, zmm_fold16, 0x01);
                let z2 = _mm512_clmulepi64_epi128(zmm_crc2, zmm_fold16, 0x01);
                let z3 = _mm512_clmulepi64_epi128(zmm_crc3, zmm_fold16, 0x01);

                zmm_crc0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold16, 0x10);
                zmm_crc1 = _mm512_clmulepi64_epi128(zmm_crc1, zmm_fold16, 0x10);
                zmm_crc2 = _mm512_clmulepi64_epi128(zmm_crc2, zmm_fold16, 0x10);
                zmm_crc3 = _mm512_clmulepi64_epi128(zmm_crc3, zmm_fold16, 0x10);

                zmm_crc0 = _mm512_ternarylogic_epi32(zmm_crc0, z0, zmm_t0, 0x96);
                zmm_crc1 = _mm512_ternarylogic_epi32(zmm_crc1, z1, zmm_t1, 0x96);
                zmm_crc2 = _mm512_ternarylogic_epi32(zmm_crc2, z2, zmm_t2, 0x96);
                zmm_crc3 = _mm512_ternarylogic_epi32(zmm_crc3, z3, zmm_t3, 0x96);

                if COPY {
                    _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>(), zmm_t0);
                    _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>().add(1), zmm_t1);
                    _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>().add(2), zmm_t2);
                    _mm512_storeu_si512(dst.as_mut_ptr().cast::<__m512i>().add(3), zmm_t3);
                    dst = &mut dst[256..];
                }

                *src = &src[256..];
            }

            // zmm_crc[0,1,2,3] -> zmm_crc0
            z0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x01);
            zmm_crc0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x10);
            zmm_crc0 = _mm512_ternarylogic_epi32(zmm_crc0, z0, zmm_crc1, 0x96);

            z0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x01);
            zmm_crc0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x10);
            zmm_crc0 = _mm512_ternarylogic_epi32(zmm_crc0, z0, zmm_crc2, 0x96);

            z0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x01);
            zmm_crc0 = _mm512_clmulepi64_epi128(zmm_crc0, zmm_fold4, 0x10);
            zmm_crc0 = _mm512_ternarylogic_epi32(zmm_crc0, z0, zmm_crc3, 0x96);

            // zmm_crc0 -> xmm_crc[0, 1, 2, 3]
            *xmm_crc0 = _mm512_extracti32x4_epi32(zmm_crc0, 0);
            *xmm_crc1 = _mm512_extracti32x4_epi32(zmm_crc0, 1);
            *xmm_crc2 = _mm512_extracti32x4_epi32(zmm_crc0, 2);
            *xmm_crc3 = _mm512_extracti32x4_epi32(zmm_crc0, 3);

            // return n bytes processed
            start_len - src.len()
        }
    }
}
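
The new CI step runs this module's crc32:: tests under Miri. For a local sanity check, the folded result can be compared against a simple bit-by-bit CRC-32 oracle; the sketch below is not part of the commit and shows only the reference side (the call into the crate's own crc32 entry point is omitted, since its exact signature is not part of this diff):

    // Reference oracle: bit-by-bit CRC-32 (IEEE, reflected, polynomial
    // 0xEDB88320), the checksum the pclmulqdq/vpclmulqdq folding paths compute.
    // `start` follows the zlib convention: pass 0 (or a previous result) here.
    fn crc32_reference(start: u32, data: &[u8]) -> u32 {
        let mut crc = !start;
        for &byte in data {
            crc ^= byte as u32;
            for _ in 0..8 {
                crc = if crc & 1 != 0 {
                    (crc >> 1) ^ 0xEDB8_8320
                } else {
                    crc >> 1
                };
            }
        }
        !crc
    }

    fn main() {
        // Standard CRC-32 check value for the ASCII string "123456789".
        assert_eq!(crc32_reference(0, b"123456789"), 0xCBF4_3926);

        // Inputs of 256 bytes or more exercise the fold-by-16 loop above;
        // a test would compare the SIMD result against this oracle.
        let data = vec![0xAB_u8; 1024];
        println!("reference crc32 = {:#010x}", crc32_reference(0, &data));
    }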
