@@ -870,6 +870,64 @@ where
870
870
}
871
871
}
872
872
873
+ /// Retains only the elements specified by the predicate until the predicate fails.
874
+ ///
875
+ /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until
876
+ /// `f(&e)` returns `Err(())`
877
+ ///
878
+ /// # Examples
879
+ ///
880
+ /// ```
881
+ /// # #[cfg(feature = "nightly")]
882
+ /// # fn test() {
883
+ /// use hashbrown::{HashTable, DefaultHashBuilder};
884
+ /// use std::hash::BuildHasher;
885
+ ///
886
+ /// let mut table = HashTable::new();
887
+ /// let hasher = DefaultHashBuilder::default();
888
+ /// let hasher = |val: &_| {
889
+ /// use core::hash::Hasher;
890
+ /// let mut state = hasher.build_hasher();
891
+ /// core::hash::Hash::hash(&val, &mut state);
892
+ /// state.finish()
893
+ /// };
894
+ /// let mut removed = 0;
895
+ /// for x in 1..=8 {
896
+ /// table.insert_unique(hasher(&x), x, hasher);
897
+ /// }
898
+ /// table.retain_with_break(|&mut v| if removed < 3 {
899
+ /// if v % 2 == 0 {
900
+ /// Ok(true)
901
+ /// } else {
902
+ /// removed += 1;
903
+ /// Ok(false)
904
+ /// }
905
+ /// } else {
906
+ /// Err(())
907
+ /// });
908
+ /// assert_eq!(table.len(), 5);
909
+ /// # }
910
+ /// # fn main() {
911
+ /// # #[cfg(feature = "nightly")]
912
+ /// # test()
913
+ /// # }
914
+ /// ```
915
+ pub fn retain_with_break (
916
+ & mut self ,
917
+ mut f : impl FnMut ( & mut T ) -> core:: result:: Result < bool , ( ) > ,
918
+ ) {
919
+ // Here we only use `iter` as a temporary, preventing use-after-free
920
+ unsafe {
921
+ for item in self . raw . iter ( ) {
922
+ match f ( item. as_mut ( ) ) {
923
+ Ok ( false ) => self . raw . erase ( item) ,
924
+ Err ( _) => break ,
925
+ _ => continue ,
926
+ }
927
+ }
928
+ }
929
+ }
930
+
873
931
/// Clears the set, returning all elements in an iterator.
874
932
///
875
933
/// # Examples
@@ -2372,12 +2430,49 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut
2372
2430
2373
2431
#[ cfg( test) ]
2374
2432
mod tests {
2433
+ use crate :: DefaultHashBuilder ;
2434
+
2375
2435
use super :: HashTable ;
2376
2436
2437
+ use core:: hash:: BuildHasher ;
2377
2438
#[ test]
2378
2439
fn test_allocation_info ( ) {
2379
2440
assert_eq ! ( HashTable :: <( ) >:: new( ) . allocation_size( ) , 0 ) ;
2380
2441
assert_eq ! ( HashTable :: <u32 >:: new( ) . allocation_size( ) , 0 ) ;
2381
2442
assert ! ( HashTable :: <u32 >:: with_capacity( 1 ) . allocation_size( ) > core:: mem:: size_of:: <u32 >( ) ) ;
2382
2443
}
2444
+
2445
+ #[ test]
2446
+ fn test_retain_with_break ( ) {
2447
+ let mut table = HashTable :: new ( ) ;
2448
+ let hasher = DefaultHashBuilder :: default ( ) ;
2449
+ let hasher = |val : & _ | {
2450
+ use core:: hash:: Hasher ;
2451
+ let mut state = hasher. build_hasher ( ) ;
2452
+ core:: hash:: Hash :: hash ( & val, & mut state) ;
2453
+ state. finish ( )
2454
+ } ;
2455
+ for x in 0 ..100 {
2456
+ table. insert_unique ( hasher ( & x) , x, hasher) ;
2457
+ }
2458
+ // looping and removing any value > 50, but stop after 40 iterations
2459
+ let mut removed = 0 ;
2460
+ table. retain_with_break ( |& mut v| {
2461
+ if removed < 40 {
2462
+ if v > 50 {
2463
+ removed += 1 ;
2464
+ Ok ( false )
2465
+ } else {
2466
+ Ok ( true )
2467
+ }
2468
+ } else {
2469
+ Err ( ( ) )
2470
+ }
2471
+ } ) ;
2472
+ assert_eq ! ( table. len( ) , 60 ) ;
2473
+ // check nothing up to 50 is removed
2474
+ for v in 0 ..=50 {
2475
+ assert_eq ! ( table. find( hasher( & v) , |& val| val == v) , Some ( & v) ) ;
2476
+ }
2477
+ }
2383
2478
}
0 commit comments