Skip to content

Commit 877ad1d

Browse files
authored
Merge pull request #51 from QuState/as-chunks
Use `.as_chunks_mut()` instead of `.chunks_exact_mut()` for better performance
2 parents 7dc39b1 + 6f6ff7d commit 877ad1d

File tree

1 file changed

+102
-92
lines changed

1 file changed

+102
-92
lines changed

src/kernels/dit.rs

Lines changed: 102 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ pub fn fft_dit_chunk_4_simd_f64(reals: &mut [f64], imags: &mut [f64]) {
3333

3434
let two = 2.0_f64;
3535

36-
reals
37-
.chunks_exact_mut(CHUNK_SIZE)
38-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
36+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
37+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
3938
.for_each(|(reals_chunk, imags_chunk)| {
4039
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
4140
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -84,9 +83,8 @@ pub fn fft_dit_chunk_4_simd_f32(reals: &mut [f32], imags: &mut [f32]) {
8483

8584
let two = 2.0_f32;
8685

87-
reals
88-
.chunks_exact_mut(CHUNK_SIZE)
89-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
86+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
87+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
9088
.for_each(|(reals_chunk, imags_chunk)| {
9189
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
9290
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -147,9 +145,8 @@ pub fn fft_dit_chunk_8_simd_f64(reals: &mut [f64], imags: &mut [f64]) {
147145
-std::f64::consts::FRAC_1_SQRT_2, // W_8^3 imag (-sqrt(2)/2)
148146
]);
149147

150-
reals
151-
.chunks_exact_mut(CHUNK_SIZE)
152-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
148+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
149+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
153150
.for_each(|(reals_chunk, imags_chunk)| {
154151
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
155152
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -204,9 +201,8 @@ pub fn fft_dit_chunk_8_simd_f32(reals: &mut [f32], imags: &mut [f32]) {
204201
-std::f32::consts::FRAC_1_SQRT_2, // W_8^3 imag (-sqrt(2)/2)
205202
]);
206203

207-
reals
208-
.chunks_exact_mut(CHUNK_SIZE)
209-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
204+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
205+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
210206
.for_each(|(reals_chunk, imags_chunk)| {
211207
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
212208
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -272,9 +268,8 @@ pub fn fft_dit_chunk_16_simd_f64(reals: &mut [f64], imags: &mut [f64]) {
272268
-0.38268343236508984, // W_16^7 = -sin(pi/8)
273269
]);
274270

275-
reals
276-
.chunks_exact_mut(CHUNK_SIZE)
277-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
271+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
272+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
278273
.for_each(|(reals_chunk, imags_chunk)| {
279274
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
280275
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -339,9 +334,8 @@ pub fn fft_dit_chunk_16_simd_f32(reals: &mut [f32], imags: &mut [f32]) {
339334
-0.382_683_43_f32, // W_16^7 = -sin(pi/8)
340335
]);
341336

342-
reals
343-
.chunks_exact_mut(CHUNK_SIZE)
344-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
337+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
338+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
345339
.for_each(|(reals_chunk, imags_chunk)| {
346340
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
347341
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -428,9 +422,8 @@ pub fn fft_dit_chunk_32_simd_f64(reals: &mut [f64], imags: &mut [f64]) {
428422
-0.19509032201612825, // W_32^15
429423
]);
430424

431-
reals
432-
.chunks_exact_mut(CHUNK_SIZE)
433-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
425+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
426+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
434427
.for_each(|(reals_chunk, imags_chunk)| {
435428
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
436429
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -535,9 +528,8 @@ pub fn fft_dit_chunk_32_simd_f32(reals: &mut [f32], imags: &mut [f32]) {
535528
-0.195_090_32_f32, // W_32^15
536529
]);
537530

538-
reals
539-
.chunks_exact_mut(CHUNK_SIZE)
540-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
531+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
532+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
541533
.for_each(|(reals_chunk, imags_chunk)| {
542534
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
543535
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -671,9 +663,8 @@ pub fn fft_dit_chunk_64_simd_f64(reals: &mut [f64], imags: &mut [f64]) {
671663
-0.0980171403295606, // W_64^31
672664
]);
673665

674-
reals
675-
.chunks_exact_mut(CHUNK_SIZE)
676-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
666+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
667+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
677668
.for_each(|(reals_chunk, imags_chunk)| {
678669
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
679670
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -846,9 +837,8 @@ pub fn fft_dit_chunk_64_simd_f32(reals: &mut [f32], imags: &mut [f32]) {
846837
-0.098_017_14_f32, // W_64^31
847838
]);
848839

849-
reals
850-
.chunks_exact_mut(CHUNK_SIZE)
851-
.zip(imags.chunks_exact_mut(CHUNK_SIZE))
840+
(reals.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
841+
.zip(imags.as_chunks_mut::<CHUNK_SIZE>().0.iter_mut())
852842
.for_each(|(reals_chunk, imags_chunk)| {
853843
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(DIST);
854844
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(DIST);
@@ -910,7 +900,6 @@ pub fn fft_dit_64_chunk_n_simd(
910900
) {
911901
const LANES: usize = 8;
912902
let chunk_size = dist << 1;
913-
let two = f64x8::splat(2.0);
914903
assert!(chunk_size >= LANES * 2);
915904

916905
reals
@@ -920,39 +909,51 @@ pub fn fft_dit_64_chunk_n_simd(
920909
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
921910
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);
922911

923-
reals_s0
924-
.chunks_exact_mut(LANES)
925-
.zip(reals_s1.chunks_exact_mut(LANES))
926-
.zip(imags_s0.chunks_exact_mut(LANES))
927-
.zip(imags_s1.chunks_exact_mut(LANES))
928-
.zip(twiddles_re.chunks_exact(LANES))
929-
.zip(twiddles_im.chunks_exact(LANES))
930-
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| {
931-
let in0_re = f64x8::new(re_s0[0..8].try_into().unwrap());
932-
let in1_re = f64x8::new(re_s1[0..8].try_into().unwrap());
933-
let in0_im = f64x8::new(im_s0[0..8].try_into().unwrap());
934-
let in1_im = f64x8::new(im_s1[0..8].try_into().unwrap());
935-
936-
let tw_re = f64x8::new(w_re[0..8].try_into().unwrap());
937-
let tw_im = f64x8::new(w_im[0..8].try_into().unwrap());
938-
939-
// out0.re = (in0.re + tw_re * in1.re) - tw_im * in1.im
940-
let out0_re = tw_im.mul_neg_add(in1_im, tw_re.mul_add(in1_re, in0_re));
941-
// out0.im = (in0.im + tw_re * in1.im) + tw_im * in1.re
942-
let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im));
943-
944-
// Use FMA for out1 = 2*in0 - out0
945-
let out1_re = two.mul_sub(in0_re, out0_re);
946-
let out1_im = two.mul_sub(in0_im, out0_im);
947-
948-
re_s0.copy_from_slice(out0_re.as_array());
949-
im_s0.copy_from_slice(out0_im.as_array());
950-
re_s1.copy_from_slice(out1_re.as_array());
951-
im_s1.copy_from_slice(out1_im.as_array());
912+
(reals_s0.as_chunks_mut::<LANES>().0.iter_mut())
913+
.zip(reals_s1.as_chunks_mut::<LANES>().0.iter_mut())
914+
.zip(imags_s0.as_chunks_mut::<LANES>().0.iter_mut())
915+
.zip(imags_s1.as_chunks_mut::<LANES>().0.iter_mut())
916+
.zip(twiddles_re.as_chunks::<LANES>().0.iter())
917+
.zip(twiddles_im.as_chunks::<LANES>().0.iter())
918+
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), tw_re), tw_im)| {
919+
fft_dit_64_chunk_n_simd_kernel(re_s0, re_s1, im_s0, im_s1, tw_re, tw_im)
952920
});
953921
});
954922
}
955923

924+
#[inline(always)] // for multiversioning to work
925+
fn fft_dit_64_chunk_n_simd_kernel(
926+
re_s0: &mut [f64; 8],
927+
re_s1: &mut [f64; 8],
928+
im_s0: &mut [f64; 8],
929+
im_s1: &mut [f64; 8],
930+
tw_re: &[f64; 8],
931+
tw_im: &[f64; 8],
932+
) {
933+
let two = f64x8::splat(2.0);
934+
let in0_re = f64x8::new(*re_s0);
935+
let in1_re = f64x8::new(*re_s1);
936+
let in0_im = f64x8::new(*im_s0);
937+
let in1_im = f64x8::new(*im_s1);
938+
939+
let tw_re = f64x8::new(*tw_re);
940+
let tw_im = f64x8::new(*tw_im);
941+
942+
// out0.re = (in0.re + tw_re * in1.re) - tw_im * in1.im
943+
let out0_re = tw_im.mul_neg_add(in1_im, tw_re.mul_add(in1_re, in0_re));
944+
// out0.im = (in0.im + tw_re * in1.im) + tw_im * in1.re
945+
let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im));
946+
947+
// Use FMA for out1 = 2*in0 - out0
948+
let out1_re = two.mul_sub(in0_re, out0_re);
949+
let out1_im = two.mul_sub(in0_im, out0_im);
950+
951+
re_s0.copy_from_slice(out0_re.as_array());
952+
im_s0.copy_from_slice(out0_im.as_array());
953+
re_s1.copy_from_slice(out1_re.as_array());
954+
im_s1.copy_from_slice(out1_im.as_array());
955+
}
956+
956957
/// General DIT butterfly for f32
957958
#[multiversion::multiversion(targets(
958959
"x86_64+avx512f+avx512bw+avx512cd+avx512dq+avx512vl+gfni",
@@ -973,7 +974,6 @@ pub fn fft_dit_32_chunk_n_simd(
973974
) {
974975
const LANES: usize = 16;
975976
let chunk_size = dist << 1;
976-
let two = f32x16::splat(2.0);
977977
assert!(chunk_size >= LANES * 2);
978978

979979
reals
@@ -983,37 +983,47 @@ pub fn fft_dit_32_chunk_n_simd(
983983
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
984984
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);
985985

986-
reals_s0
987-
.chunks_exact_mut(LANES)
988-
.zip(reals_s1.chunks_exact_mut(LANES))
989-
.zip(imags_s0.chunks_exact_mut(LANES))
990-
.zip(imags_s1.chunks_exact_mut(LANES))
991-
.zip(twiddles_re.chunks_exact(LANES))
992-
.zip(twiddles_im.chunks_exact(LANES))
993-
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), w_re), w_im)| {
994-
let in0_re = f32x16::new(re_s0[0..16].try_into().unwrap());
995-
let in1_re = f32x16::new(re_s1[0..16].try_into().unwrap());
996-
let in0_im = f32x16::new(im_s0[0..16].try_into().unwrap());
997-
let in1_im = f32x16::new(im_s1[0..16].try_into().unwrap());
998-
999-
let tw_re = f32x16::new(w_re[0..16].try_into().unwrap());
1000-
let tw_im = f32x16::new(w_im[0..16].try_into().unwrap());
1001-
1002-
// Exactly 6 FMAs for DIT butterfly:
1003-
// tw_im contains negative values (standard twiddle convention)
1004-
// out0.re = (in0.re + tw_re*in1.re) - tw_im*in1.im
1005-
let out0_re = tw_im.mul_neg_add(in1_im, tw_re.mul_add(in1_re, in0_re));
1006-
// out0.im = (in0.im + tw_re*in1.im) + tw_im*in1.re
1007-
let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im));
1008-
1009-
// Use FMA for out1 = 2*in0 - out0
1010-
let out1_re = two.mul_sub(in0_re, out0_re);
1011-
let out1_im = two.mul_sub(in0_im, out0_im);
1012-
1013-
re_s0.copy_from_slice(out0_re.as_array());
1014-
im_s0.copy_from_slice(out0_im.as_array());
1015-
re_s1.copy_from_slice(out1_re.as_array());
1016-
im_s1.copy_from_slice(out1_im.as_array());
986+
(reals_s0.as_chunks_mut::<LANES>().0.iter_mut())
987+
.zip(reals_s1.as_chunks_mut::<LANES>().0.iter_mut())
988+
.zip(imags_s0.as_chunks_mut::<LANES>().0.iter_mut())
989+
.zip(imags_s1.as_chunks_mut::<LANES>().0.iter_mut())
990+
.zip(twiddles_re.as_chunks::<LANES>().0.iter())
991+
.zip(twiddles_im.as_chunks::<LANES>().0.iter())
992+
.for_each(|(((((re_s0, re_s1), im_s0), im_s1), tw_re), tw_im)| {
993+
fft_dit_32_chunk_n_simd_kernel(re_s0, re_s1, im_s0, im_s1, tw_re, tw_im)
1017994
});
1018995
});
1019996
}
997+
998+
#[inline(always)] // for multiversioning to work
999+
fn fft_dit_32_chunk_n_simd_kernel(
1000+
re_s0: &mut [f32; 16],
1001+
re_s1: &mut [f32; 16],
1002+
im_s0: &mut [f32; 16],
1003+
im_s1: &mut [f32; 16],
1004+
tw_re: &[f32; 16],
1005+
tw_im: &[f32; 16],
1006+
) {
1007+
let two = f32x16::splat(2.0);
1008+
let in0_re = f32x16::new(*re_s0);
1009+
let in1_re = f32x16::new(*re_s1);
1010+
let in0_im = f32x16::new(*im_s0);
1011+
let in1_im = f32x16::new(*im_s1);
1012+
1013+
let tw_re = f32x16::new(*tw_re);
1014+
let tw_im = f32x16::new(*tw_im);
1015+
1016+
// out0.re = (in0.re + tw_re * in1.re) - tw_im * in1.im
1017+
let out0_re = tw_im.mul_neg_add(in1_im, tw_re.mul_add(in1_re, in0_re));
1018+
// out0.im = (in0.im + tw_re * in1.im) + tw_im * in1.re
1019+
let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im));
1020+
1021+
// Use FMA for out1 = 2*in0 - out0
1022+
let out1_re = two.mul_sub(in0_re, out0_re);
1023+
let out1_im = two.mul_sub(in0_im, out0_im);
1024+
1025+
re_s0.copy_from_slice(out0_re.as_array());
1026+
im_s0.copy_from_slice(out0_im.as_array());
1027+
re_s1.copy_from_slice(out1_re.as_array());
1028+
im_s1.copy_from_slice(out1_im.as_array());
1029+
}

0 commit comments

Comments
 (0)