@@ -12,9 +12,6 @@ extern crate blas_src;
1212
1313use blas:: zdrot;
1414use blas:: zscal;
15- use ndarray:: s;
16-
17- use ndarray:: Zip ;
1815use numpy:: Complex64 ;
1916use numpy:: PyReadonlyArray1 ;
2017
@@ -31,34 +28,55 @@ pub fn apply_givens_rotation_in_place(
3128 slice1 : PyReadonlyArray1 < usize > ,
3229 slice2 : PyReadonlyArray1 < usize > ,
3330) {
31+ if slice1. is_empty ( ) . unwrap ( ) {
32+ return ;
33+ }
34+
3435 let mut vec = vec. as_array_mut ( ) ;
3536 let slice1 = slice1. as_array ( ) ;
3637 let slice2 = slice2. as_array ( ) ;
37- let shape = vec. shape ( ) ;
38- let dim_b = shape [ 1 ] as i32 ;
38+ let dim_b = vec. shape ( ) [ 1 ] ;
39+ let dim_b_i32 = dim_b as i32 ;
3940 let s_abs = s. norm ( ) ;
4041 let angle = s. arg ( ) ;
4142 let phase = Complex64 :: new ( angle. cos ( ) , angle. sin ( ) ) ;
4243 let phase_conj = phase. conj ( ) ;
4344
44- Zip :: from ( & slice1) . and ( & slice2) . for_each ( |& i, & j| {
45- let ( mut row_i, mut row_j) = vec. multi_slice_mut ( ( s ! [ i, ..] , s ! [ j, ..] ) ) ;
46- match row_i. as_slice_mut ( ) {
47- Some ( row_i) => match row_j. as_slice_mut ( ) {
48- Some ( row_j) => unsafe {
49- // Use zdrot and zscal because zrot is not currently available
50- // See https://github.com/qiskit-community/ffsim/issues/28
51- zscal ( dim_b, phase_conj, row_i, 1 ) ;
52- zdrot ( dim_b, row_i, 1 , row_j, 1 , c, s_abs) ;
53- zscal ( dim_b, phase, row_i, 1 ) ;
54- } ,
55- None => panic ! (
56- "Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
57- ) ,
58- } ,
59- None => panic ! (
60- "Failed to convert ArrayBase to slice, possibly because the data was not contiguous and in standard order."
61- ) ,
62- } ;
45+ let n_pairs = slice1. len ( ) ;
46+ let n_threads = std:: env:: var ( "RAYON_NUM_THREADS" )
47+ . ok ( )
48+ . and_then ( |s| s. parse ( ) . ok ( ) )
49+ . unwrap_or_else ( || {
50+ std:: thread:: available_parallelism ( )
51+ . map ( |n| n. get ( ) )
52+ . unwrap_or ( 1 )
53+ } )
54+ . min ( n_pairs) ;
55+ let chunk_size = n_pairs. div_ceil ( n_threads) ;
56+ let vec_ptr = vec. as_mut_ptr ( ) as usize ;
57+ let slice1 = slice1. as_slice ( ) . unwrap ( ) ;
58+ let slice2 = slice2. as_slice ( ) . unwrap ( ) ;
59+
60+ std:: thread:: scope ( |scope| {
61+ for k in 0 ..n_threads {
62+ let start = k * chunk_size;
63+ let end = ( start + chunk_size) . min ( n_pairs) ;
64+ let slice1_chunk = & slice1[ start..end] ;
65+ let slice2_chunk = & slice2[ start..end] ;
66+ scope. spawn ( move || {
67+ let vec_ptr = vec_ptr as * mut Complex64 ;
68+ for ( & i, & j) in slice1_chunk. iter ( ) . zip ( slice2_chunk) {
69+ unsafe {
70+ let row_i = std:: slice:: from_raw_parts_mut ( vec_ptr. add ( i * dim_b) , dim_b) ;
71+ let row_j = std:: slice:: from_raw_parts_mut ( vec_ptr. add ( j * dim_b) , dim_b) ;
72+ // Use zdrot and zscal because zrot is not currently available
73+ // See https://github.com/qiskit-community/ffsim/issues/28
74+ zscal ( dim_b_i32, phase_conj, row_i, 1 ) ;
75+ zdrot ( dim_b_i32, row_i, 1 , row_j, 1 , c, s_abs) ;
76+ zscal ( dim_b_i32, phase, row_i, 1 ) ;
77+ }
78+ }
79+ } ) ;
80+ }
6381 } ) ;
6482}
0 commit comments