diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index a9f06f6..5e81514 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -170,6 +170,12 @@ where pub fn get(&self, i: usize, search: &Search) -> Option> { Some(MapItem::from(self.hnsw.get(i, search)?, self)) } + + pub fn insert(&mut self, point: P, value: V) -> Result> { + let point_id = self.hnsw.insert(point, 100, Some(Heuristic::default())); + self.values.push(value); + Ok(point_id) + } } pub struct MapItem<'a, P, V> { @@ -394,6 +400,55 @@ where pub fn get(&self, i: usize, search: &Search) -> Option> { Some(Item::new(search.nearest.get(i).copied()?, self)) } + + pub fn insert( + &mut self, + point: P, + ef_construction: usize, + heuristic: Option, + ) -> PointId { + let new_pid = self.points.len(); + let new_point_id = PointId(new_pid as u32); + + self.points.push(point); + self.zero.push(ZeroNode::default()); + + let zeros = self + .zero + .iter() + .map(|z| RwLock::new(z.clone())) + .collect::>(); + + let top = if self.layers.is_empty() { + LayerId(0) + } else { + LayerId(self.layers.len()) + }; + + let construction = Construction { + zero: zeros.as_slice(), + pool: SearchPool::new(self.points.len()), + top, + points: self.points.as_slice(), + heuristic, + ef_construction, + #[cfg(feature = "indicatif")] + progress: None, + #[cfg(feature = "indicatif")] + done: AtomicUsize::new(0), + }; + + let new_layer = construction.top; + construction.insert(new_point_id, new_layer, &self.layers); + + self.zero = construction + .zero + .iter() + .map(|node| node.read().clone()) + .collect(); + + new_point_id + } } pub struct Item<'a, P> { diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index b9fa973..c895c5a 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -96,3 +96,85 @@ impl instant_distance::Point for Point { ((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt() } } + +#[test] +#[allow(clippy::float_cmp, clippy::approx_constant)] +fn incremental_insert() { + let points = (0..4) + .map(|i| Point(i as f32, i as f32)) + .collect::>(); + let values = vec!["zero", "one", "two", "three"]; + let seed = ThreadRng::default().gen::(); + let builder = Builder::default().seed(seed); + + let mut map = builder.build(points, values); + + map.insert(Point(4.0, 4.0), "four").expect("Should insert"); + + let mut search = Search::default(); + + for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() { + match i { + 0 => { + assert_eq!(item.distance, 0.0); + assert_eq!(item.value, &"four"); + } + 1 => { + assert_eq!(item.distance, 1.4142135); + assert!(item.value == &"three"); + } + 2 => { + assert_eq!(item.distance, 2.828427); + assert!(item.value == &"two"); + } + 3 => { + assert_eq!(item.distance, 4.2426405); + assert!(item.value == &"one"); + } + 4 => { + assert_eq!(item.distance, 5.656854); + assert!(item.value == &"zero"); + } + _ => unreachable!(), + } + } + + // Note + // This has the same expected results as incremental_insert but builds + // the whole map in one go. Only here for comparison. + { + let points = (0..5) + .map(|i| Point(i as f32, i as f32)) + .collect::>(); + let values = vec!["zero", "one", "two", "three", "four"]; + let seed = ThreadRng::default().gen::(); + let builder = Builder::default().seed(seed); + let map = builder.build(points, values); + let mut search = Search::default(); + for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() { + match i { + 0 => { + assert_eq!(item.distance, 0.0); + assert_eq!(item.value, &"four"); + } + 1 => { + assert_eq!(item.distance, 1.4142135); + assert!(item.value == &"three"); + } + 2 => { + assert_eq!(item.distance, 2.828427); + assert!(item.value == &"two"); + } + 3 => { + assert_eq!(item.distance, 4.2426405); + assert!(item.value == &"one"); + } + 4 => { + assert_eq!(item.distance, 5.656854); + assert!(item.value == &"zero"); + } + _ => unreachable!(), + } + } + } +}