Skip to content

Commit 17dc9f3

Browse files
authored
Add ordered pairs for FastPair (#252)
* Add ordered_pairs method to FastPair * add tests to fastpair
1 parent c8ec8fe commit 17dc9f3

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

src/algorithm/neighbour/fastpair.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
173173
}
174174
}
175175

176+
///
177+
/// Return order dissimilarities from closest to furthest
178+
///
179+
#[allow(dead_code)]
180+
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
181+
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
182+
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
183+
let mut distances = self
184+
.distances
185+
.values()
186+
.collect::<Vec<&PairwiseDistance<T>>>();
187+
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
188+
distances.into_iter()
189+
}
190+
176191
//
177192
// Compute distances from input to all other points in data-structure.
178193
// input is the row index of the sample matrix
@@ -588,4 +603,103 @@ mod tests_fastpair {
588603

589604
assert_eq!(closest, min_dissimilarity);
590605
}
606+
607+
#[test]
608+
fn fastpair_ordered_pairs() {
609+
let x = DenseMatrix::<f64>::from_2d_array(&[
610+
&[5.1, 3.5, 1.4, 0.2],
611+
&[4.9, 3.0, 1.4, 0.2],
612+
&[4.7, 3.2, 1.3, 0.2],
613+
&[4.6, 3.1, 1.5, 0.2],
614+
&[5.0, 3.6, 1.4, 0.2],
615+
&[5.4, 3.9, 1.7, 0.4],
616+
&[4.9, 3.1, 1.5, 0.1],
617+
&[7.0, 3.2, 4.7, 1.4],
618+
&[6.4, 3.2, 4.5, 1.5],
619+
&[6.9, 3.1, 4.9, 1.5],
620+
&[5.5, 2.3, 4.0, 1.3],
621+
&[6.5, 2.8, 4.6, 1.5],
622+
&[4.6, 3.4, 1.4, 0.3],
623+
&[5.0, 3.4, 1.5, 0.2],
624+
&[4.4, 2.9, 1.4, 0.2],
625+
])
626+
.unwrap();
627+
let fastpair = FastPair::new(&x).unwrap();
628+
629+
let ordered = fastpair.ordered_pairs();
630+
631+
let mut previous: f64 = -1.0;
632+
for p in ordered {
633+
if previous == -1.0 {
634+
previous = p.distance.unwrap();
635+
} else {
636+
let current = p.distance.unwrap();
637+
assert!(current >= previous);
638+
previous = current;
639+
}
640+
}
641+
}
642+
643+
#[test]
644+
fn test_empty_set() {
645+
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
646+
let result = FastPair::new(&empty_matrix);
647+
assert!(result.is_err());
648+
if let Err(e) = result {
649+
assert_eq!(
650+
e,
651+
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
652+
);
653+
}
654+
}
655+
656+
#[test]
657+
fn test_single_point() {
658+
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
659+
let result = FastPair::new(&single_point);
660+
assert!(result.is_err());
661+
if let Err(e) = result {
662+
assert_eq!(
663+
e,
664+
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
665+
);
666+
}
667+
}
668+
669+
#[test]
670+
fn test_two_points() {
671+
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
672+
let result = FastPair::new(&two_points);
673+
assert!(result.is_err());
674+
if let Err(e) = result {
675+
assert_eq!(
676+
e,
677+
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
678+
);
679+
}
680+
}
681+
682+
#[test]
683+
fn test_three_identical_points() {
684+
let identical_points =
685+
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
686+
let result = FastPair::new(&identical_points);
687+
assert!(result.is_ok());
688+
let fastpair = result.unwrap();
689+
let closest_pair = fastpair.closest_pair();
690+
assert_eq!(closest_pair.distance, Some(0.0));
691+
}
692+
693+
#[test]
694+
fn test_result_unwrapping() {
695+
let valid_matrix =
696+
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
697+
.unwrap();
698+
699+
let result = FastPair::new(&valid_matrix);
700+
assert!(result.is_ok());
701+
702+
// This should not panic
703+
let _fastpair = result.unwrap();
704+
}
591705
}

0 commit comments

Comments
 (0)