Skip to content

Commit b4f71cb

Browse files
committed
parallelize orbital rotation
1 parent be5de0f commit b4f71cb

File tree

1 file changed

+42
-24
lines changed

1 file changed

+42
-24
lines changed

src/gates/orbital_rotation.rs

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ extern crate blas_src;
1212

1313
use blas::zdrot;
1414
use blas::zscal;
15-
use ndarray::s;
16-
17-
use ndarray::Zip;
1815
use numpy::Complex64;
1916
use numpy::PyReadonlyArray1;
2017

@@ -31,34 +28,55 @@ pub fn apply_givens_rotation_in_place(
3128
slice1: PyReadonlyArray1<usize>,
3229
slice2: PyReadonlyArray1<usize>,
3330
) {
31+
if slice1.is_empty().unwrap() {
32+
return;
33+
}
34+
3435
let mut vec = vec.as_array_mut();
3536
let slice1 = slice1.as_array();
3637
let slice2 = slice2.as_array();
37-
let shape = vec.shape();
38-
let dim_b = shape[1] as i32;
38+
let dim_b = vec.shape()[1];
39+
let dim_b_i32 = dim_b as i32;
3940
let s_abs = s.norm();
4041
let angle = s.arg();
4142
let phase = Complex64::new(angle.cos(), angle.sin());
4243
let phase_conj = phase.conj();
4344

44-
Zip::from(&slice1).and(&slice2).for_each(|&i, &j| {
45-
let (mut row_i, mut row_j) = vec.multi_slice_mut((s![i, ..], s![j, ..]));
46-
match row_i.as_slice_mut() {
47-
Some(row_i) => match row_j.as_slice_mut() {
48-
Some(row_j) => unsafe {
49-
// Use zdrot and zscal because zrot is not currently available
50-
// See https://github.com/qiskit-community/ffsim/issues/28
51-
zscal(dim_b, phase_conj, row_i, 1);
52-
zdrot(dim_b, row_i, 1, row_j, 1, c, s_abs);
53-
zscal(dim_b, phase, row_i, 1);
54-
},
55-
None => panic!(
56-
"Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
57-
),
58-
},
59-
None => panic!(
60-
"Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
61-
),
62-
};
45+
let n_pairs = slice1.len();
46+
let n_threads = std::env::var("RAYON_NUM_THREADS")
47+
.ok()
48+
.and_then(|s| s.parse().ok())
49+
.unwrap_or_else(|| {
50+
std::thread::available_parallelism()
51+
.map(|n| n.get())
52+
.unwrap_or(1)
53+
})
54+
.min(n_pairs);
55+
let chunk_size = n_pairs.div_ceil(n_threads);
56+
let vec_ptr = vec.as_mut_ptr() as usize;
57+
let slice1 = slice1.as_slice().unwrap();
58+
let slice2 = slice2.as_slice().unwrap();
59+
60+
std::thread::scope(|scope| {
61+
for k in 0..n_threads {
62+
let start = k * chunk_size;
63+
let end = (start + chunk_size).min(n_pairs);
64+
let slice1_chunk = &slice1[start..end];
65+
let slice2_chunk = &slice2[start..end];
66+
scope.spawn(move || {
67+
let vec_ptr = vec_ptr as *mut Complex64;
68+
for (&i, &j) in slice1_chunk.iter().zip(slice2_chunk) {
69+
unsafe {
70+
let row_i = std::slice::from_raw_parts_mut(vec_ptr.add(i * dim_b), dim_b);
71+
let row_j = std::slice::from_raw_parts_mut(vec_ptr.add(j * dim_b), dim_b);
72+
// Use zdrot and zscal because zrot is not currently available
73+
// See https://github.com/qiskit-community/ffsim/issues/28
74+
zscal(dim_b_i32, phase_conj, row_i, 1);
75+
zdrot(dim_b_i32, row_i, 1, row_j, 1, c, s_abs);
76+
zscal(dim_b_i32, phase, row_i, 1);
77+
}
78+
}
79+
});
80+
}
6381
});
6482
}

0 commit comments

Comments
 (0)