Skip to content

Commit 5a9cffd

Browse files
committed
special case for n_threads=1
1 parent c798443 commit 5a9cffd

File tree

1 file changed

+45
-12
lines changed

1 file changed

+45
-12
lines changed

src/gates/orbital_rotation.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,64 @@ pub fn apply_givens_rotation_in_place(
5252
.unwrap_or(1)
5353
})
5454
.min(n_pairs);
55+
56+
let slice1_slice = slice1.as_slice().unwrap();
57+
let slice2_slice = slice2.as_slice().unwrap();
58+
59+
// Sequential execution for single thread
60+
if n_threads == 1 {
61+
let vec_ptr = vec.as_mut_ptr();
62+
for (&i, &j) in slice1_slice.iter().zip(slice2_slice) {
63+
unsafe {
64+
_apply_givens_rotation_to_pair(
65+
vec_ptr, i, j, dim_b, dim_b_i32, c, s_abs, phase, phase_conj,
66+
);
67+
}
68+
}
69+
return;
70+
}
71+
72+
// Parallel execution
5573
let chunk_size = n_pairs.div_ceil(n_threads);
5674
let vec_ptr = vec.as_mut_ptr() as usize;
57-
let slice1 = slice1.as_slice().unwrap();
58-
let slice2 = slice2.as_slice().unwrap();
59-
6075
std::thread::scope(|scope| {
6176
for k in 0..n_threads {
6277
let start = k * chunk_size;
6378
let end = (start + chunk_size).min(n_pairs);
64-
let slice1_chunk = &slice1[start..end];
65-
let slice2_chunk = &slice2[start..end];
79+
let slice1_chunk = &slice1_slice[start..end];
80+
let slice2_chunk = &slice2_slice[start..end];
6681
scope.spawn(move || {
6782
let vec_ptr = vec_ptr as *mut Complex64;
6883
for (&i, &j) in slice1_chunk.iter().zip(slice2_chunk) {
6984
unsafe {
70-
let row_i = std::slice::from_raw_parts_mut(vec_ptr.add(i * dim_b), dim_b);
71-
let row_j = std::slice::from_raw_parts_mut(vec_ptr.add(j * dim_b), dim_b);
72-
// Use zdrot and zscal because zrot is not currently available
73-
// See https://github.com/qiskit-community/ffsim/issues/28
74-
zscal(dim_b_i32, phase_conj, row_i, 1);
75-
zdrot(dim_b_i32, row_i, 1, row_j, 1, c, s_abs);
76-
zscal(dim_b_i32, phase, row_i, 1);
85+
_apply_givens_rotation_to_pair(
86+
vec_ptr, i, j, dim_b, dim_b_i32, c, s_abs, phase, phase_conj,
87+
);
7788
}
7889
}
7990
});
8091
}
8192
});
8293
}
94+
95+
/// Apply Givens rotation to a pair of rows
96+
#[allow(clippy::too_many_arguments)]
97+
unsafe fn _apply_givens_rotation_to_pair(
98+
vec_ptr: *mut Complex64,
99+
i: usize,
100+
j: usize,
101+
dim_b: usize,
102+
dim_b_i32: i32,
103+
c: f64,
104+
s_abs: f64,
105+
phase: Complex64,
106+
phase_conj: Complex64,
107+
) {
108+
let row_i = std::slice::from_raw_parts_mut(vec_ptr.add(i * dim_b), dim_b);
109+
let row_j = std::slice::from_raw_parts_mut(vec_ptr.add(j * dim_b), dim_b);
110+
// Use zdrot and zscal because zrot is not currently available
111+
// See https://github.com/qiskit-community/ffsim/issues/28
112+
zscal(dim_b_i32, phase_conj, row_i, 1);
113+
zdrot(dim_b_i32, row_i, 1, row_j, 1, c, s_abs);
114+
zscal(dim_b_i32, phase, row_i, 1);
115+
}

0 commit comments

Comments
 (0)