diff --git a/src/kernels/dit.rs b/src/kernels/dit.rs
index 5c1f50e..98d0aee 100644
--- a/src/kernels/dit.rs
+++ b/src/kernels/dit.rs
@@ -1105,3 +1105,119 @@ fn fft_dit_chunk_n_simd_f32(
         });
     });
 }
+
+/// General DIT butterfly for f32 with two passes fused together
+#[inline(never)] // otherwise every kernel gets inlined into the parent and ARM perf drops due to register pressure
+pub fn fft_dit_chunk_n_f32_fused<S: Simd>(
+    simd: S,
+    reals: &mut [f32],
+    imags: &mut [f32],
+    twiddles_re: &[f32],
+    twiddles_im: &[f32],
+    twiddles_next_re: &[f32],
+    twiddles_next_im: &[f32],
+    dist: usize,
+) {
+    simd.vectorize(
+        #[inline(always)]
+        || {
+            fft_dit_chunk_n_simd_f32_fused(
+                simd,
+                reals,
+                imags,
+                twiddles_re,
+                twiddles_im,
+                twiddles_next_re,
+                twiddles_next_im,
+                dist,
+            )
+        },
+    )
+}
+
+/// General DIT butterfly for f32 with two passes fused together
+/// to do twice the work per pass over the memory,
+/// which helps performance because large FFTs are overwhelmingly memory-bottlenecked
+#[inline(always)] // required by fearless_simd
+fn fft_dit_chunk_n_simd_f32_fused<S: Simd>(
+    simd: S,
+    reals: &mut [f32],
+    imags: &mut [f32],
+    twiddles_re: &[f32],
+    twiddles_im: &[f32],
+    twiddles_next_re: &[f32],
+    twiddles_next_im: &[f32],
+    dist: usize,
+) {
+    const LANES: usize = 16;
+    let chunk_size = dist << 1;
+    assert!(chunk_size >= LANES * 2);
+
+    reals
+        .chunks_exact_mut(chunk_size)
+        .zip(imags.chunks_exact_mut(chunk_size))
+        .for_each(|(reals_chunk, imags_chunk)| {
+            let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
+            let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);
+
+            (reals_s0.as_chunks_mut::<LANES>().0.iter_mut())
+                .zip(reals_s1.as_chunks_mut::<LANES>().0.iter_mut())
+                .zip(imags_s0.as_chunks_mut::<LANES>().0.iter_mut())
+                .zip(imags_s1.as_chunks_mut::<LANES>().0.iter_mut())
+                .zip(twiddles_re.as_chunks::<LANES>().0.iter())
+                .zip(twiddles_im.as_chunks::<LANES>().0.iter())
+                .zip(twiddles_next_re.as_chunks::<LANES>().0.iter())
+                .zip(twiddles_next_im.as_chunks::<LANES>().0.iter())
+                .for_each(
+                    |(((((((re_s0, re_s1), im_s0), im_s1), tw_re), tw_im), tw2_re), tw2_im)| {
+                        let two = f32x16::splat(simd, 2.0);
+                        let in0_re = f32x16::simd_from(simd, *re_s0);
+                        let in1_re = f32x16::simd_from(simd, *re_s1);
+                        let in0_im = f32x16::simd_from(simd, *im_s0);
+                        let in1_im = f32x16::simd_from(simd, *im_s1);
+
+                        let tw_re = f32x16::simd_from(simd, *tw_re);
+                        let tw_im = f32x16::simd_from(simd, *tw_im);
+
+                        // out0.re = (in0.re + tw_re * in1.re) - tw_im * in1.im
+                        let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re));
+                        // out0.im = (in0.im + tw_re * in1.im) + tw_im * in1.re
+                        let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im));
+
+                        // Use FMA for out1 = 2*in0 - out0
+                        let out1_re = two.mul_sub(in0_re, out0_re);
+                        let out1_im = two.mul_sub(in0_im, out0_im);
+
+                        // Repeat the operation on the outputs of the previous step with the next stage's twiddles.
+                        // This lets us reuse data already loaded into registers without going through another load/store,
+                        // so that instead of (6 loads + 4 FMAs + 4 stores) * 2 we get 8 loads + 8 FMAs + 4 stores.
+                        // Given that modern CPUs are absurdly memory-bottlenecked,
+                        // this should help performance a lot, especially on larger sizes.
+
+                        // TODO: this might be wrong. `chunk_size` changes between stages. Need to account for that.
+
+                        let in0_re = out0_re;
+                        let in0_im = out0_im;
+                        let in1_re = out1_re;
+                        let in1_im = out1_im;
+
+                        let tw_re = f32x16::simd_from(simd, *tw2_re);
+                        let tw_im = f32x16::simd_from(simd, *tw2_im);
+
+                        // out0.re = (in0.re + tw_re * in1.re) - tw_im * in1.im
+                        let out0_re = tw_im.mul_add(-in1_im, tw_re.mul_add(in1_re, in0_re));
+                        // out0.im = (in0.im + tw_re * in1.im) + tw_im * in1.re
+                        let out0_im = tw_im.mul_add(in1_re, tw_re.mul_add(in1_im, in0_im));
+
+                        // Use FMA for out1 = 2*in0 - out0
+                        let out1_re = two.mul_sub(in0_re, out0_re);
+                        let out1_im = two.mul_sub(in0_im, out0_im);
+
+                        out0_re.store_slice(re_s0);
+                        out0_im.store_slice(im_s0);
+                        out1_re.store_slice(re_s1);
+                        out1_im.store_slice(im_s1);
+                    },
+                );
+        });
+}
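For reference, each pass above is the standard radix-2 DIT butterfly, `out0 = in0 + tw * in1` and `out1 = in0 - tw * in1`, with the second output rewritten as `2 * in0 - out0` so it maps onto a single `mul_sub`. The scalar sketch below is not part of the patch; `Complex`, `butterfly`, and `butterfly_fused` are illustrative names. It models the per-lane arithmetic the fused loop performs as written, including feeding the first pass's outputs straight into the second pass with the next twiddles, so the pairing caveat in the TODO applies to it just as much as to the SIMD code.

```rust
// Minimal scalar model of the fused butterfly, for illustration only.
#[derive(Clone, Copy, Debug)]
struct Complex {
    re: f32,
    im: f32,
}

/// One radix-2 DIT butterfly: out0 = in0 + tw * in1, out1 = in0 - tw * in1,
/// with out1 expressed as 2 * in0 - out0 to mirror the `mul_sub` form in the kernel.
fn butterfly(in0: Complex, in1: Complex, tw: Complex) -> (Complex, Complex) {
    // t = tw * in1 (complex multiplication)
    let t = Complex {
        re: tw.re * in1.re - tw.im * in1.im,
        im: tw.re * in1.im + tw.im * in1.re,
    };
    let out0 = Complex {
        re: in0.re + t.re,
        im: in0.im + t.im,
    };
    let out1 = Complex {
        re: 2.0 * in0.re - out0.re,
        im: 2.0 * in0.im - out0.im,
    };
    (out0, out1)
}

/// Two butterflies back to back on the same pair, reusing the first pass's
/// outputs as the second pass's inputs, as the fused SIMD loop does in registers.
fn butterfly_fused(
    in0: Complex,
    in1: Complex,
    tw: Complex,
    tw_next: Complex,
) -> (Complex, Complex) {
    let (a, b) = butterfly(in0, in1, tw);
    butterfly(a, b, tw_next)
}

fn main() {
    let in0 = Complex { re: 1.0, im: 2.0 };
    let in1 = Complex { re: 0.5, im: -1.0 };
    let tw = Complex { re: 0.0, im: -1.0 }; // e^{-i*pi/2}
    let tw_next = Complex { re: 1.0, im: 0.0 }; // identity twiddle
    println!("{:?}", butterfly_fused(in0, in1, tw, tw_next));
}
```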