@@ -1074,6 +1074,27 @@ where
1074
1074
/// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
1075
1075
/// form must match those for the key type.
1076
1076
pub fn remove < ' g , Q > ( & ' g self , key : & Q , guard : & ' g Guard ) -> Option < & ' g V >
1077
+ where
1078
+ K : Borrow < Q > ,
1079
+ Q : ?Sized + Hash + Eq ,
1080
+ {
1081
+ self . replace_node ( key, None , None , guard)
1082
+ }
1083
+
1084
+ /// Replaces node value with v, conditional upon match of cv.
1085
+ /// If resulting value does not exist it removes the key (and its corresponding value) from this map.
1086
+ /// This method does nothing if the key is not in the map.
1087
+ /// Returns the previous value associated with the given key.
1088
+ ///
1089
+ /// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
1090
+ /// form must match those for the key type.
1091
+ fn replace_node < ' g , Q > (
1092
+ & ' g self ,
1093
+ key : & Q ,
1094
+ new_value : Option < V > ,
1095
+ observed_value : Option < Shared < ' g , V > > ,
1096
+ guard : & ' g Guard ,
1097
+ ) -> Option < & ' g V >
1077
1098
where
1078
1099
K : Borrow < Q > ,
1079
1100
Q : ?Sized + Hash + Eq ,
@@ -1152,26 +1173,36 @@ where
1152
1173
let next = n. next . load ( Ordering :: SeqCst , guard) ;
1153
1174
if n. hash == h && n. key . borrow ( ) == key {
1154
1175
let ev = n. value . load ( Ordering :: SeqCst , guard) ;
1155
- old_val = Some ( ev) ;
1156
-
1157
- // remove the BinEntry containing the removed key value pair from the bucket
1158
- if !pred. is_null ( ) {
1159
- // either by changing the pointer of the previous BinEntry, if present
1160
- // safety: as above
1161
- unsafe { pred. deref ( ) }
1162
- . as_node ( )
1163
- . unwrap ( )
1164
- . next
1165
- . store ( next, Ordering :: SeqCst ) ;
1166
- } else {
1167
- // or by setting the next node as the first BinEntry if there is no previous entry
1168
- t. store_bin ( i, next) ;
1169
- }
1170
-
1171
- // in either case, mark the BinEntry as garbage, since it was just removed
1172
- // safety: as for val below / in put
1173
- unsafe { guard. defer_destroy ( e) } ;
1174
1176
1177
+ // just remove the node if the value is the one we expected at method call
1178
+ if observed_value. map ( |ov| ov == ev) . unwrap_or ( true ) {
1179
+ // found the node but we have a new value to replace the old one
1180
+ if let Some ( nv) = new_value {
1181
+ n. value . store ( Owned :: new ( nv) , Ordering :: SeqCst ) ;
1182
+ // we are just replacing entry value and we do not want to remove the node
1183
+ // so we stop iterating here
1184
+ break ;
1185
+ }
1186
+ // we remember the old value so that we can return it and mark it for deletion below
1187
+ old_val = Some ( ev) ;
1188
+ // remove the BinEntry containing the removed key value pair from the bucket
1189
+ if !pred. is_null ( ) {
1190
+ // either by changing the pointer of the previous BinEntry, if present
1191
+ // safety: as above
1192
+ unsafe { pred. deref ( ) }
1193
+ . as_node ( )
1194
+ . unwrap ( )
1195
+ . next
1196
+ . store ( next, Ordering :: SeqCst ) ;
1197
+ } else {
1198
+ // or by setting the next node as the first BinEntry if there is no previous entry
1199
+ t. store_bin ( i, next) ;
1200
+ }
1201
+
1202
+ // in either case, mark the BinEntry as garbage, since it was just removed
1203
+ // safety: as for val below / in put
1204
+ unsafe { guard. defer_destroy ( e) } ;
1205
+ }
1175
1206
// since the key was found and only one node exists per key, we can break here
1176
1207
break ;
1177
1208
}
@@ -1218,6 +1249,46 @@ where
1218
1249
None
1219
1250
}
1220
1251
1252
+ /// Retains only the elements specified by the predicate.
1253
+ ///
1254
+ /// In other words, remove all pairs (k, v) such that f(&k,&v) returns false.
1255
+ ///
1256
+ /// If `f` returns `false` for a given key/value pair, but the value for that pair is concurrently
1257
+ /// modified before the removal takes place, the entry will not be removed.
1258
+ /// If you want the removal to happen even in the case of concurrent modification, use [`HashMap::retain_force`].
1259
+ pub fn retain < F > ( & mut self , mut f : F )
1260
+ where
1261
+ F : FnMut ( & K , & V ) -> bool ,
1262
+ {
1263
+ let guard = epoch:: pin ( ) ;
1264
+ // removed selected keys
1265
+ for ( k, v) in self . iter ( & guard) {
1266
+ if !f ( k, v) {
1267
+ let old_value: Shared < ' _ , V > = Shared :: from ( v as * const V ) ;
1268
+ self . replace_node ( k, None , Some ( old_value) , & guard) ;
1269
+ }
1270
+ }
1271
+ }
1272
+
1273
+ /// Retains only the elements specified by the predicate.
1274
+ ///
1275
+ /// In other words, remove all pairs (k, v) such that f(&k,&v) returns false.
1276
+ ///
1277
+ /// This method always deletes any key/value pair that `f` returns `false` for,
1278
+ /// even if if the value is updated concurrently. If you do not want that behavior, use [`HashMap::retain`].
1279
+ pub fn retain_force < F > ( & mut self , mut f : F )
1280
+ where
1281
+ F : FnMut ( & K , & V ) -> bool ,
1282
+ {
1283
+ let guard = epoch:: pin ( ) ;
1284
+ // removed selected keys
1285
+ for ( k, v) in self . iter ( & guard) {
1286
+ if !f ( k, v) {
1287
+ self . replace_node ( k, None , None , & guard) ;
1288
+ }
1289
+ }
1290
+ }
1291
+
1221
1292
/// An iterator visiting all key-value pairs in arbitrary order.
1222
1293
/// The iterator element type is `(&'g K, &'g V)`.
1223
1294
///
@@ -1604,5 +1675,54 @@ mod tests {
1604
1675
/// drop(guard);
1605
1676
/// drop(r);
1606
1677
/// ```
1678
+
1679
+ #[ test]
1680
+ fn replace_empty ( ) {
1681
+ let map = HashMap :: < usize , usize > :: new ( ) ;
1682
+
1683
+ {
1684
+ let guard = epoch:: pin ( ) ;
1685
+ let old = map. replace_node ( & 42 , None , None , & guard) ;
1686
+ assert ! ( old. is_none( ) ) ;
1687
+ }
1688
+ }
1689
+
1690
+ #[ test]
1691
+ fn replace_existing ( ) {
1692
+ let map = HashMap :: < usize , usize > :: new ( ) ;
1693
+ {
1694
+ let guard = epoch:: pin ( ) ;
1695
+ map. insert ( 42 , 42 , & guard) ;
1696
+ let old = map. replace_node ( & 42 , Some ( 10 ) , None , & guard) ;
1697
+ assert ! ( old. is_none( ) ) ;
1698
+ assert_eq ! ( * map. get( & 42 , & guard) . unwrap( ) , 10 ) ;
1699
+ }
1700
+ }
1701
+
1702
+ #[ test]
1703
+ fn replace_existing_observed_value_matching ( ) {
1704
+ let map = HashMap :: < usize , usize > :: new ( ) ;
1705
+ {
1706
+ let guard = epoch:: pin ( ) ;
1707
+ map. insert ( 42 , 42 , & guard) ;
1708
+ let observed_value = Shared :: from ( map. get ( & 42 , & guard) . unwrap ( ) as * const _ ) ;
1709
+ let old = map. replace_node ( & 42 , Some ( 10 ) , Some ( observed_value) , & guard) ;
1710
+ assert ! ( old. is_none( ) ) ;
1711
+ assert_eq ! ( * map. get( & 42 , & guard) . unwrap( ) , 10 ) ;
1712
+ }
1713
+ }
1714
+
1715
+ #[ test]
1716
+ fn replace_existing_observed_value_non_matching ( ) {
1717
+ let map = HashMap :: < usize , usize > :: new ( ) ;
1718
+ {
1719
+ let guard = epoch:: pin ( ) ;
1720
+ map. insert ( 42 , 42 , & guard) ;
1721
+ let old = map. replace_node ( & 42 , Some ( 10 ) , Some ( Shared :: null ( ) ) , & guard) ;
1722
+ assert ! ( old. is_none( ) ) ;
1723
+ assert_eq ! ( * map. get( & 42 , & guard) . unwrap( ) , 42 ) ;
1724
+ }
1725
+ }
1726
+
1607
1727
#[ allow( dead_code) ]
1608
1728
struct CompileFailTests ;
0 commit comments