Skip to content

Commit 78d334b

Browse files
committed
Recursively grow stack on heap whenever necessary.
* Add test with large array of equal floats. * Enable optimization for test profile to reduce execution time.
1 parent 3e06315 commit 78d334b

File tree

3 files changed

+56
-26
lines changed

3 files changed

+56
-26
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ num-traits = "0.2"
2323
rand = "0.8.3"
2424
itertools = { version = "0.10.0", default-features = false }
2525
indexmap = "1.6.2"
26+
stacker = "0.1.15"
2627

2728
[dev-dependencies]
2829
ndarray = { version = "0.15.0", features = ["approx"] }
@@ -44,3 +45,6 @@ harness = false
4445
[[bench]]
4546
name = "deviation"
4647
harness = false
48+
49+
[profile.test]
50+
opt-level = 2

src/sort.rs

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
use indexmap::IndexMap;
22
use ndarray::prelude::*;
33
use ndarray::{Data, DataMut, Slice};
4+
use stacker::maybe_grow;
5+
6+
/// Guaranteed stack size per recursion step of 1 MiB.
7+
const RED_ZONE: usize = 1_024 * 1_024;
8+
/// New stack space of 8 MiB to allocate if within [`RED_ZONE`].
9+
const STACK_SIZE: usize = 8 * RED_ZONE;
410

511
/// Methods for sorting and partitioning 1-D arrays.
612
pub trait Sort1dExt<A, S>
@@ -357,7 +363,9 @@ where
357363
// Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must
358364
// be non-empty.
359365
let mut values = vec![array[0].clone(); indexes.len()];
360-
_get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values);
366+
maybe_grow(RED_ZONE, STACK_SIZE, || {
367+
_get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values);
368+
});
361369

362370
// We convert the vector to a more search-friendly `IndexMap`.
363371
indexes.iter().cloned().zip(values.into_iter()).collect()
@@ -451,21 +459,25 @@ fn _get_many_from_sorted_mut_unchecked<A>(
451459

452460
// We search recursively for the values corresponding to indexes strictly less than
453461
// `pivot_index` in the lower partition.
454-
_get_many_from_sorted_mut_unchecked(
455-
array.slice_axis_mut(Axis(0), Slice::from(..pivot_index)),
456-
lower_indexes,
457-
lower_values,
458-
);
462+
maybe_grow(RED_ZONE, STACK_SIZE, || {
463+
_get_many_from_sorted_mut_unchecked(
464+
array.slice_axis_mut(Axis(0), Slice::from(..pivot_index)),
465+
lower_indexes,
466+
lower_values,
467+
);
468+
});
459469

460470
// We search recursively for the values corresponding to indexes greater than or equal
461471
// `pivot_index` in the upper partition. Since only the upper partition of the array is
462472
// passed in, the indexes need to be shifted by length of the lower partition.
463473
upper_indexes.iter_mut().for_each(|x| *x -= pivot_index + 1);
464-
_get_many_from_sorted_mut_unchecked(
465-
array.slice_axis_mut(Axis(0), Slice::from(pivot_index + 1..)),
466-
upper_indexes,
467-
upper_values,
468-
);
474+
maybe_grow(RED_ZONE, STACK_SIZE, || {
475+
_get_many_from_sorted_mut_unchecked(
476+
array.slice_axis_mut(Axis(0), Slice::from(pivot_index + 1..)),
477+
upper_indexes,
478+
upper_values,
479+
);
480+
});
469481

470482
return;
471483
}
@@ -519,32 +531,38 @@ fn _get_many_from_sorted_mut_unchecked<A>(
519531

520532
// We search recursively for the values corresponding to indexes strictly less than
521533
// `lower_index` in the lower partition.
522-
_get_many_from_sorted_mut_unchecked(
523-
array.slice_axis_mut(Axis(0), Slice::from(..lower_index)),
524-
lower_indexes,
525-
lower_values,
526-
);
534+
maybe_grow(RED_ZONE, STACK_SIZE, || {
535+
_get_many_from_sorted_mut_unchecked(
536+
array.slice_axis_mut(Axis(0), Slice::from(..lower_index)),
537+
lower_indexes,
538+
lower_values,
539+
);
540+
});
527541

528542
// We search recursively for the values corresponding to indexes greater than or equal
529543
// `lower_index` in the inner partition, that is between the lower and upper partition. Since
530544
// only the inner partition of the array is passed in, the indexes need to be shifted by length
531545
// of the lower partition.
532546
inner_indexes.iter_mut().for_each(|x| *x -= lower_index + 1);
533-
_get_many_from_sorted_mut_unchecked(
534-
array.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index)),
535-
inner_indexes,
536-
inner_values,
537-
);
547+
maybe_grow(RED_ZONE, STACK_SIZE, || {
548+
_get_many_from_sorted_mut_unchecked(
549+
array.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index)),
550+
inner_indexes,
551+
inner_values,
552+
);
553+
});
538554

539555
// We search recursively for the values corresponding to indexes greater than or equal
540556
// `upper_index` in the upper partition. Since only the upper partition of the array is passed
541557
// in, the indexes need to be shifted by the combined length of the lower and inner partition.
542558
upper_indexes.iter_mut().for_each(|x| *x -= upper_index + 1);
543-
_get_many_from_sorted_mut_unchecked(
544-
array.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..)),
545-
upper_indexes,
546-
upper_values,
547-
);
559+
maybe_grow(RED_ZONE, STACK_SIZE, || {
560+
_get_many_from_sorted_mut_unchecked(
561+
array.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..)),
562+
upper_indexes,
563+
upper_values,
564+
);
565+
});
548566
}
549567

550568
/// Equally space `sample` indexes around the center of `array` and sort them by their values.

tests/sort.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use ndarray::prelude::*;
22
use ndarray_stats::Sort1dExt;
3+
use ndarray_stats::{interpolate::Linear, Quantile1dExt};
4+
use noisy_float::types::{n64, N64};
35
use quickcheck_macros::quickcheck;
46

57
#[test]
@@ -63,6 +65,12 @@ fn test_dual_partition_mut() {
6365
}
6466
}
6567

68+
#[test]
69+
fn test_quantile_mut_with_large_array_of_equal_floats() {
70+
let mut array: Array1<N64> = Array1::ones(100000);
71+
array.quantile_mut(n64(0.5), &Linear).unwrap();
72+
}
73+
6674
#[test]
6775
fn test_sorted_get_mut() {
6876
let a = arr1(&[1, 3, 2, 10]);

0 commit comments

Comments
 (0)