@@ -870,6 +870,64 @@ where
870870 }
871871 }
872872
873+ /// Retains only the elements specified by the predicate until the predicate fails.
874+ ///
875+ /// In other words, remove all elements `e` such that `f(&e)` returns `Ok(false)` until
876+ /// `f(&e)` returns `Err(())`
877+ ///
878+ /// # Examples
879+ ///
880+ /// ```
881+ /// # #[cfg(feature = "nightly")]
882+ /// # fn test() {
883+ /// use hashbrown::{HashTable, DefaultHashBuilder};
884+ /// use std::hash::BuildHasher;
885+ ///
886+ /// let mut table = HashTable::new();
887+ /// let hasher = DefaultHashBuilder::default();
888+ /// let hasher = |val: &_| {
889+ /// use core::hash::Hasher;
890+ /// let mut state = hasher.build_hasher();
891+ /// core::hash::Hash::hash(&val, &mut state);
892+ /// state.finish()
893+ /// };
894+ /// let mut removed = 0;
895+ /// for x in 1..=8 {
896+ /// table.insert_unique(hasher(&x), x, hasher);
897+ /// }
898+ /// table.retain_with_break(|&mut v| if removed < 3 {
899+ /// if v % 2 == 0 {
900+ /// Ok(true)
901+ /// } else {
902+ /// removed += 1;
903+ /// Ok(false)
904+ /// }
905+ /// } else {
906+ /// Err(())
907+ /// });
908+ /// assert_eq!(table.len(), 5);
909+ /// # }
910+ /// # fn main() {
911+ /// # #[cfg(feature = "nightly")]
912+ /// # test()
913+ /// # }
914+ /// ```
915+ pub fn retain_with_break (
916+ & mut self ,
917+ mut f : impl FnMut ( & mut T ) -> core:: result:: Result < bool , ( ) > ,
918+ ) {
919+ // Here we only use `iter` as a temporary, preventing use-after-free
920+ unsafe {
921+ for item in self . raw . iter ( ) {
922+ match f ( item. as_mut ( ) ) {
923+ Ok ( false ) => self . raw . erase ( item) ,
924+ Err ( _) => break ,
925+ _ => continue ,
926+ }
927+ }
928+ }
929+ }
930+
873931 /// Clears the set, returning all elements in an iterator.
874932 ///
875933 /// # Examples
@@ -2372,12 +2430,49 @@ impl<T, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut
23722430
23732431#[ cfg( test) ]
23742432mod tests {
2433+ use crate :: DefaultHashBuilder ;
2434+
23752435 use super :: HashTable ;
23762436
2437+ use core:: hash:: BuildHasher ;
23772438 #[ test]
23782439 fn test_allocation_info ( ) {
23792440 assert_eq ! ( HashTable :: <( ) >:: new( ) . allocation_size( ) , 0 ) ;
23802441 assert_eq ! ( HashTable :: <u32 >:: new( ) . allocation_size( ) , 0 ) ;
23812442 assert ! ( HashTable :: <u32 >:: with_capacity( 1 ) . allocation_size( ) > core:: mem:: size_of:: <u32 >( ) ) ;
23822443 }
2444+
2445+ #[ test]
2446+ fn test_retain_with_break ( ) {
2447+ let mut table = HashTable :: new ( ) ;
2448+ let hasher = DefaultHashBuilder :: default ( ) ;
2449+ let hasher = |val : & _ | {
2450+ use core:: hash:: Hasher ;
2451+ let mut state = hasher. build_hasher ( ) ;
2452+ core:: hash:: Hash :: hash ( & val, & mut state) ;
2453+ state. finish ( )
2454+ } ;
2455+ for x in 0 ..100 {
2456+ table. insert_unique ( hasher ( & x) , x, hasher) ;
2457+ }
2458+ // looping and removing any value > 50, but stop after 40 iterations
2459+ let mut removed = 0 ;
2460+ table. retain_with_break ( |& mut v| {
2461+ if removed < 40 {
2462+ if v > 50 {
2463+ removed += 1 ;
2464+ Ok ( false )
2465+ } else {
2466+ Ok ( true )
2467+ }
2468+ } else {
2469+ Err ( ( ) )
2470+ }
2471+ } ) ;
2472+ assert_eq ! ( table. len( ) , 60 ) ;
2473+ // check nothing up to 50 is removed
2474+ for v in 0 ..=50 {
2475+ assert_eq ! ( table. find( hasher( & v) , |& val| val == v) , Some ( & v) ) ;
2476+ }
2477+ }
23832478}
0 commit comments