Skip to content

Commit 7aaba0f

Browse files
authored
Merge pull request #67 from QuState/refactor-dit-planner
Refactor dit planner
2 parents 964dd4c + da3ceec commit 7aaba0f

File tree

1 file changed

+47
-87
lines changed

1 file changed

+47
-87
lines changed

src/planner.rs

Lines changed: 47 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -76,102 +76,62 @@ macro_rules! impl_planner_for {
7676
impl_planner_for!(Planner64, f64, generate_twiddles_simd_64);
7777
impl_planner_for!(Planner32, f32, generate_twiddles_simd_32);
7878

79-
/// DIT-specific planner for f64 that pre-computes twiddles for all stages
80-
pub struct PlannerDit64 {
81-
/// Twiddles for each stage that needs them (stages with chunk_size > 32)
82-
/// Each element contains (twiddles_re, twiddles_im) for that stage
83-
pub stage_twiddles: Vec<(Vec<f64>, Vec<f64>)>,
84-
/// The direction of the FFT
85-
pub direction: Direction,
86-
/// The log2 of the FFT size
87-
pub log_n: usize,
88-
}
89-
90-
impl PlannerDit64 {
91-
/// Create a DIT planner for an FFT of size `num_points`
92-
pub fn new(num_points: usize, direction: Direction) -> Self {
93-
assert!(num_points > 0 && num_points.is_power_of_two());
94-
95-
let log_n = num_points.ilog2() as usize;
96-
let mut stage_twiddles = Vec::new();
97-
98-
// Pre-compute twiddles for each stage that needs them
99-
for stage in 0..log_n {
100-
let dist = 1 << stage; // 2.pow(stage)
101-
let chunk_size = dist * 2;
102-
103-
// Only stages with chunk_size > 64 need twiddles (we have SIMD kernels up to 64)
104-
if chunk_size > 64 {
105-
let mut twiddles_re = vec![0.0f64; dist];
106-
let mut twiddles_im = vec![0.0f64; dist];
107-
108-
let angle_mult = -2.0 * std::f64::consts::PI / chunk_size as f64;
109-
for k in 0..dist {
110-
let angle = angle_mult * k as f64;
111-
twiddles_re[k] = angle.cos();
112-
twiddles_im[k] = angle.sin();
113-
}
114-
115-
stage_twiddles.push((twiddles_re, twiddles_im));
116-
}
117-
}
118-
119-
Self {
120-
stage_twiddles,
121-
direction,
122-
log_n,
79+
macro_rules! impl_planner_dit_for {
80+
($struct_name:ident, $precision:ident) => {
81+
/// DIT-specific planner that pre-computes twiddles for all stages
82+
pub struct $struct_name {
83+
/// Twiddles for each stage that needs them (stages with chunk_size > 64)
84+
/// Each element contains (twiddles_re, twiddles_im) for that stage
85+
pub stage_twiddles: Vec<(Vec<$precision>, Vec<$precision>)>,
86+
/// The direction of the FFT
87+
pub direction: Direction,
88+
/// The log2 of the FFT size
89+
pub log_n: usize,
12390
}
124-
}
125-
}
12691

127-
/// DIT-specific planner for f32 that pre-computes twiddles for all stages
128-
pub struct PlannerDit32 {
129-
/// Twiddles for each stage that needs them (stages with chunk_size > 32)
130-
/// Each element contains (twiddles_re, twiddles_im) for that stage
131-
pub stage_twiddles: Vec<(Vec<f32>, Vec<f32>)>,
132-
/// The direction of the FFT
133-
pub direction: Direction,
134-
/// The log2 of the FFT size
135-
pub log_n: usize,
136-
}
92+
impl $struct_name {
93+
/// Create a DIT planner for an FFT of size `num_points`
94+
pub fn new(num_points: usize, direction: Direction) -> Self {
95+
assert!(num_points > 0 && num_points.is_power_of_two());
13796

138-
impl PlannerDit32 {
139-
/// Create a DIT planner for an FFT of size `num_points`
140-
pub fn new(num_points: usize, direction: Direction) -> Self {
141-
assert!(num_points > 0 && num_points.is_power_of_two());
142-
143-
let log_n = num_points.ilog2() as usize;
144-
let mut stage_twiddles = Vec::new();
145-
146-
// Pre-compute twiddles for each stage that needs them
147-
for stage in 0..log_n {
148-
let dist = 1 << stage; // 2.pow(stage)
149-
let chunk_size = dist * 2;
150-
151-
// Only stages with chunk_size > 64 need twiddles (we have SIMD kernels up to 64)
152-
if chunk_size > 64 {
153-
let mut twiddles_re = vec![0.0f32; dist];
154-
let mut twiddles_im = vec![0.0f32; dist];
155-
156-
let angle_mult = -2.0 * std::f32::consts::PI / chunk_size as f32;
157-
for k in 0..dist {
158-
let angle = angle_mult * k as f32;
159-
twiddles_re[k] = angle.cos();
160-
twiddles_im[k] = angle.sin();
97+
let log_n = num_points.ilog2() as usize;
98+
let mut stage_twiddles = Vec::new();
99+
100+
// Pre-compute twiddles for each stage that needs them
101+
for stage in 0..log_n {
102+
let dist = 1 << stage; // 2.pow(stage)
103+
let chunk_size = dist * 2;
104+
105+
// Only stages with chunk_size > 64 need twiddles (we have SIMD kernels up to 64)
106+
if chunk_size > 64 {
107+
let mut twiddles_re = vec![0.0 as $precision; dist];
108+
let mut twiddles_im = vec![0.0 as $precision; dist];
109+
110+
let angle_mult =
111+
-2.0 * std::$precision::consts::PI / chunk_size as $precision;
112+
for k in 0..dist {
113+
let angle = angle_mult * k as $precision;
114+
twiddles_re[k] = angle.cos();
115+
twiddles_im[k] = angle.sin();
116+
}
117+
118+
stage_twiddles.push((twiddles_re, twiddles_im));
119+
}
161120
}
162121

163-
stage_twiddles.push((twiddles_re, twiddles_im));
122+
Self {
123+
stage_twiddles,
124+
direction,
125+
log_n,
126+
}
164127
}
165128
}
166-
167-
Self {
168-
stage_twiddles,
169-
direction,
170-
log_n,
171-
}
172-
}
129+
};
173130
}
174131

132+
impl_planner_dit_for!(PlannerDit64, f64);
133+
impl_planner_dit_for!(PlannerDit32, f32);
134+
175135
#[cfg(test)]
176136
mod tests {
177137
use super::*;

0 commit comments

Comments
 (0)