Skip to content

Commit 7dc39b1

Browse files
authored
Merge pull request #48 from QuState/feature/cache-blocked-dit-fft
perf: add initial cache-blocked dit fft impl
2 parents 2e67b5c + 790861d commit 7dc39b1

File tree

1 file changed

+180
-50
lines changed

1 file changed

+180
-50
lines changed

src/algorithms/dit.rs

Lines changed: 180 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,184 @@ use crate::kernels::dit::{
2525
use crate::options::Options;
2626
use crate::planner::{Direction, PlannerDit32, PlannerDit64};
2727

28+
/// L1 cache block size in complex elements (8KB for f32, 16KB for f64,
/// counting the real and imaginary arrays together).
///
/// The recursive DIT drivers in this file stop subdividing once a sub-transform
/// has at most this many elements and run all of its stages on that block directly.
29+
const L1_BLOCK_SIZE: usize = 1024;
30+
31+
/// Recursive cache-blocked DIT FFT for f64 using post-order traversal.
32+
///
33+
/// Recursively divides by 2 until reaching L1 cache size, processes stages within
34+
/// each block, then processes cross-block stages on return.
///
/// Works in place on `reals[offset..offset + size]` and `imags[offset..offset + size]`.
///
/// * `size <= L1_BLOCK_SIZE`: runs stages `0..size.ilog2()` in ascending order on the
///   block, threading `stage_twiddle_idx` through every `execute_dit_stage_f64` call.
/// * otherwise: recurses into both halves (each started with twiddle index 0, return
///   values discarded), then runs stages `log_half..log_size` over the whole range.
///   The incoming `stage_twiddle_idx` is ignored on this path; it is recomputed as
///   `log_half.saturating_sub(6)`.
///
/// Returns `stage_twiddle_idx` as updated by the stages executed at this level.
///
/// # Panics
///
/// Panics if `size` is 0 (`ilog2` of zero). When a stage runs, slicing panics if
/// `offset + size` exceeds the length of either slice.
///
/// NOTE(review): presumably `size` is always a power of two; nothing here checks it.
35+
fn recursive_dit_fft_f64(
36+
reals: &mut [f64],
37+
imags: &mut [f64],
38+
offset: usize,
39+
size: usize,
40+
planner: &PlannerDit64,
41+
mut stage_twiddle_idx: usize,
42+
) -> usize {
43+
let log_size = size.ilog2() as usize;
44+
// Base case: the block is small enough, so run every stage on it back to back.
45+
if size <= L1_BLOCK_SIZE {
46+
for stage in 0..log_size {
47+
stage_twiddle_idx = execute_dit_stage_f64(
48+
&mut reals[offset..offset + size],
49+
&mut imags[offset..offset + size],
50+
stage,
51+
planner,
52+
stage_twiddle_idx,
53+
);
54+
}
55+
stage_twiddle_idx
56+
} else {
57+
let half = size / 2;
58+
let log_half = half.ilog2() as usize;
59+
60+
// Recursively process both halves; each restarts at twiddle index 0 and the
// returned index is discarded because it is recomputed below.
61+
recursive_dit_fft_f64(reals, imags, offset, half, planner, 0);
62+
recursive_dit_fft_f64(reals, imags, offset + half, half, planner, 0);
63+
64+
// Both halves completed stages 0..log_half-1
65+
// Stages 0-5 use hardcoded twiddles, 6+ use planner, so stage `s >= 6` reads
// planner entry `s - 6`; the next stage to run is `log_half` (clamped at 0).
66+
stage_twiddle_idx = log_half.saturating_sub(6);
67+
68+
// Process remaining stages that span both halves. Because `half == size / 2`,
// `log_size == log_half + 1`, so this range holds exactly one stage.
69+
for stage in log_half..log_size {
70+
stage_twiddle_idx = execute_dit_stage_f64(
71+
&mut reals[offset..offset + size],
72+
&mut imags[offset..offset + size],
73+
stage,
74+
planner,
75+
stage_twiddle_idx,
76+
);
77+
}
78+
79+
stage_twiddle_idx
80+
}
81+
}
82+
83+
/// Recursive cache-blocked DIT FFT for f32 using post-order traversal.
///
/// Works in place on `reals[offset..offset + size]` and `imags[offset..offset + size]`.
///
/// * `size <= L1_BLOCK_SIZE`: runs stages `0..size.ilog2()` in ascending order on the
///   block, threading `stage_twiddle_idx` through every `execute_dit_stage_f32` call.
/// * otherwise: recurses into both halves (each started with twiddle index 0, return
///   values discarded), then runs stages `log_half..log_size` over the whole range.
///   The incoming `stage_twiddle_idx` is ignored on this path; it is recomputed as
///   `log_half.saturating_sub(6)`.
///
/// Returns `stage_twiddle_idx` as updated by the stages executed at this level.
///
/// # Panics
///
/// Panics if `size` is 0 (`ilog2` of zero). When a stage runs, slicing panics if
/// `offset + size` exceeds the length of either slice.
///
/// NOTE(review): presumably `size` is always a power of two; nothing here checks it.
84+
fn recursive_dit_fft_f32(
85+
reals: &mut [f32],
86+
imags: &mut [f32],
87+
offset: usize,
88+
size: usize,
89+
planner: &PlannerDit32,
90+
mut stage_twiddle_idx: usize,
91+
) -> usize {
92+
let log_size = size.ilog2() as usize;
93+
// Base case: the block is small enough, so run every stage on it back to back.
94+
if size <= L1_BLOCK_SIZE {
95+
for stage in 0..log_size {
96+
stage_twiddle_idx = execute_dit_stage_f32(
97+
&mut reals[offset..offset + size],
98+
&mut imags[offset..offset + size],
99+
stage,
100+
planner,
101+
stage_twiddle_idx,
102+
);
103+
}
104+
stage_twiddle_idx
105+
} else {
106+
let half = size / 2;
107+
let log_half = half.ilog2() as usize;
108+
// Transform each half independently; each restarts at twiddle index 0 and the
// returned index is discarded because it is recomputed below.
109+
recursive_dit_fft_f32(reals, imags, offset, half, planner, 0);
110+
recursive_dit_fft_f32(reals, imags, offset + half, half, planner, 0);
111+
// Stages below 6 (chunk sizes 2..=64) consume no planner entry, so stage `s >= 6`
// reads planner entry `s - 6`; the next stage to run is `log_half` (clamped at 0).
112+
stage_twiddle_idx = log_half.saturating_sub(6);
113+
// Run the stage(s) that span both halves. Because `half == size / 2`,
// `log_size == log_half + 1`, so this range holds exactly one stage.
114+
for stage in log_half..log_size {
115+
stage_twiddle_idx = execute_dit_stage_f32(
116+
&mut reals[offset..offset + size],
117+
&mut imags[offset..offset + size],
118+
stage,
119+
planner,
120+
stage_twiddle_idx,
121+
);
122+
}
123+
124+
stage_twiddle_idx
125+
}
126+
}
127+
128+
/// Execute a single DIT stage, dispatching to appropriate kernel based on chunk size.
129+
/// Returns updated stage_twiddle_idx.
///
/// With `dist = 1 << stage` and `chunk_size = 2 * dist`:
/// * chunk sizes 2, 4, 8, 16, 32 and 64 call the matching fixed-size kernel, which is
///   passed no planner twiddles; `stage_twiddle_idx` is returned unchanged.
/// * any larger chunk size reads `planner.stage_twiddles[stage_twiddle_idx]`, calls
///   `fft_dit_64_chunk_n_simd` with `dist`, and returns `stage_twiddle_idx + 1`.
///
/// NOTE(review): the `planner.stage_twiddles[...]` index is not range-checked here;
/// presumably it panics when `stage_twiddle_idx` is out of bounds — confirm.
130+
#[inline]
131+
fn execute_dit_stage_f64(
132+
reals: &mut [f64],
133+
imags: &mut [f64],
134+
stage: usize,
135+
planner: &PlannerDit64,
136+
stage_twiddle_idx: usize,
137+
) -> usize {
// `dist` is 2^stage; a chunk is twice that.
138+
let dist = 1 << stage;
139+
let chunk_size = dist << 1;
140+
// Stages 0-5 (chunk sizes 2..=64): dedicated kernels, twiddle index untouched.
141+
if chunk_size == 2 {
142+
fft_dit_chunk_2(reals, imags);
143+
stage_twiddle_idx
144+
} else if chunk_size == 4 {
145+
fft_dit_chunk_4_simd_f64(reals, imags);
146+
stage_twiddle_idx
147+
} else if chunk_size == 8 {
148+
fft_dit_chunk_8_simd_f64(reals, imags);
149+
stage_twiddle_idx
150+
} else if chunk_size == 16 {
151+
fft_dit_chunk_16_simd_f64(reals, imags);
152+
stage_twiddle_idx
153+
} else if chunk_size == 32 {
154+
fft_dit_chunk_32_simd_f64(reals, imags);
155+
stage_twiddle_idx
156+
} else if chunk_size == 64 {
157+
fft_dit_chunk_64_simd_f64(reals, imags);
158+
stage_twiddle_idx
159+
} else {
160+
// For larger chunks, use general kernel with twiddles from planner.
// Each such stage consumes one planner entry, hence the `+ 1` on return.
161+
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
162+
fft_dit_64_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist);
163+
stage_twiddle_idx + 1
164+
}
165+
}
166+
167+
/// Execute a single DIT stage, dispatching to appropriate kernel based on chunk size.
168+
/// Returns updated stage_twiddle_idx.
///
/// With `dist = 1 << stage` and `chunk_size = 2 * dist`:
/// * chunk sizes 2, 4, 8, 16, 32 and 64 call the matching fixed-size kernel, which is
///   passed no planner twiddles; `stage_twiddle_idx` is returned unchanged.
/// * any larger chunk size reads `planner.stage_twiddles[stage_twiddle_idx]`, calls
///   `fft_dit_32_chunk_n_simd` with `dist`, and returns `stage_twiddle_idx + 1`.
///
/// NOTE(review): the `planner.stage_twiddles[...]` index is not range-checked here;
/// presumably it panics when `stage_twiddle_idx` is out of bounds — confirm.
169+
#[inline]
170+
fn execute_dit_stage_f32(
171+
reals: &mut [f32],
172+
imags: &mut [f32],
173+
stage: usize,
174+
planner: &PlannerDit32,
175+
stage_twiddle_idx: usize,
176+
) -> usize {
// `dist` is 2^stage; a chunk is twice that.
177+
let dist = 1 << stage;
178+
let chunk_size = dist << 1;
179+
// Stages 0-5 (chunk sizes 2..=64): dedicated kernels, twiddle index untouched.
180+
if chunk_size == 2 {
181+
fft_dit_chunk_2(reals, imags);
182+
stage_twiddle_idx
183+
} else if chunk_size == 4 {
184+
fft_dit_chunk_4_simd_f32(reals, imags);
185+
stage_twiddle_idx
186+
} else if chunk_size == 8 {
187+
fft_dit_chunk_8_simd_f32(reals, imags);
188+
stage_twiddle_idx
189+
} else if chunk_size == 16 {
190+
fft_dit_chunk_16_simd_f32(reals, imags);
191+
stage_twiddle_idx
192+
} else if chunk_size == 32 {
193+
fft_dit_chunk_32_simd_f32(reals, imags);
194+
stage_twiddle_idx
195+
} else if chunk_size == 64 {
196+
fft_dit_chunk_64_simd_f32(reals, imags);
197+
stage_twiddle_idx
198+
} else {
199+
// For larger chunks, use general kernel with twiddles from planner.
// Each such stage consumes one planner entry, hence the `+ 1` on return.
200+
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
201+
fft_dit_32_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist);
202+
stage_twiddle_idx + 1
203+
}
204+
}
205+
28206
/// DIT FFT for f64 with pre-computed planner and options
29207
///
30208
/// This implementation uses the Decimation-in-Time algorithm which:
@@ -74,31 +252,7 @@ pub fn fft_64_dit_with_planner_and_opts(
74252
}
75253
}
76254

77-
// DIT processes from small to large butterflies
78-
let mut stage_twiddle_idx = 0;
79-
for stage in 0..log_n {
80-
let dist = 1 << stage;
81-
let chunk_size = dist << 1;
82-
83-
if chunk_size == 2 {
84-
fft_dit_chunk_2(reals, imags);
85-
} else if chunk_size == 4 {
86-
fft_dit_chunk_4_simd_f64(reals, imags);
87-
} else if chunk_size == 8 {
88-
fft_dit_chunk_8_simd_f64(reals, imags);
89-
} else if chunk_size == 16 {
90-
fft_dit_chunk_16_simd_f64(reals, imags);
91-
} else if chunk_size == 32 {
92-
fft_dit_chunk_32_simd_f64(reals, imags);
93-
} else if chunk_size == 64 {
94-
fft_dit_chunk_64_simd_f64(reals, imags);
95-
} else {
96-
// For larger chunks, use general kernel with twiddles from planner
97-
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
98-
fft_dit_64_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist);
99-
stage_twiddle_idx += 1;
100-
}
101-
}
255+
recursive_dit_fft_f64(reals, imags, 0, n, planner, 0);
102256

103257
// Scaling for inverse transform
104258
if let Direction::Reverse = planner.direction {
@@ -145,31 +299,7 @@ pub fn fft_32_dit_with_planner_and_opts(
145299
}
146300
}
147301

148-
// DIT processes from small to large butterflies
149-
let mut stage_twiddle_idx = 0;
150-
for stage in 0..log_n {
151-
let dist = 1 << stage;
152-
let chunk_size = dist << 1;
153-
154-
if chunk_size == 2 {
155-
fft_dit_chunk_2(reals, imags);
156-
} else if chunk_size == 4 {
157-
fft_dit_chunk_4_simd_f32(reals, imags);
158-
} else if chunk_size == 8 {
159-
fft_dit_chunk_8_simd_f32(reals, imags);
160-
} else if chunk_size == 16 {
161-
fft_dit_chunk_16_simd_f32(reals, imags);
162-
} else if chunk_size == 32 {
163-
fft_dit_chunk_32_simd_f32(reals, imags);
164-
} else if chunk_size == 64 {
165-
fft_dit_chunk_64_simd_f32(reals, imags);
166-
} else {
167-
// For larger chunks, use general kernel with twiddles from planner
168-
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
169-
fft_dit_32_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist);
170-
stage_twiddle_idx += 1;
171-
}
172-
}
302+
recursive_dit_fft_f32(reals, imags, 0, n, planner, 0);
173303

174304
// Scaling for inverse transform
175305
if let Direction::Reverse = planner.direction {

0 commit comments

Comments
 (0)