Skip to content

Commit 9544337

Browse files
authored
Merge pull request #36 from danielSanchezQ/retain
retain method
2 parents 5fe084b + 6b62e60 commit 9544337

File tree

2 files changed

+186
-19
lines changed

2 files changed

+186
-19
lines changed

src/map.rs

Lines changed: 139 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,27 @@ where
10741074
/// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
10751075
/// form must match those for the key type.
10761076
pub fn remove<'g, Q>(&'g self, key: &Q, guard: &'g Guard) -> Option<&'g V>
1077+
where
1078+
K: Borrow<Q>,
1079+
Q: ?Sized + Hash + Eq,
1080+
{
1081+
self.replace_node(key, None, None, guard)
1082+
}
1083+
1084+
/// Replaces node value with v, conditional upon match of cv.
1085+
/// If resulting value does not exist it removes the key (and its corresponding value) from this map.
1086+
/// This method does nothing if the key is not in the map.
1087+
/// Returns the previous value associated with the given key.
1088+
///
1089+
/// The key may be any borrowed form of the map's key type, but `Hash` and `Eq` on the borrowed
1090+
/// form must match those for the key type.
1091+
fn replace_node<'g, Q>(
1092+
&'g self,
1093+
key: &Q,
1094+
new_value: Option<V>,
1095+
observed_value: Option<Shared<'g, V>>,
1096+
guard: &'g Guard,
1097+
) -> Option<&'g V>
10771098
where
10781099
K: Borrow<Q>,
10791100
Q: ?Sized + Hash + Eq,
@@ -1152,26 +1173,36 @@ where
11521173
let next = n.next.load(Ordering::SeqCst, guard);
11531174
if n.hash == h && n.key.borrow() == key {
11541175
let ev = n.value.load(Ordering::SeqCst, guard);
1155-
old_val = Some(ev);
1156-
1157-
// remove the BinEntry containing the removed key value pair from the bucket
1158-
if !pred.is_null() {
1159-
// either by changing the pointer of the previous BinEntry, if present
1160-
// safety: as above
1161-
unsafe { pred.deref() }
1162-
.as_node()
1163-
.unwrap()
1164-
.next
1165-
.store(next, Ordering::SeqCst);
1166-
} else {
1167-
// or by setting the next node as the first BinEntry if there is no previous entry
1168-
t.store_bin(i, next);
1169-
}
1170-
1171-
// in either case, mark the BinEntry as garbage, since it was just removed
1172-
// safety: as for val below / in put
1173-
unsafe { guard.defer_destroy(e) };
11741176

1177+
// just remove the node if the value is the one we expected at method call
1178+
if observed_value.map(|ov| ov == ev).unwrap_or(true) {
1179+
// found the node but we have a new value to replace the old one
1180+
if let Some(nv) = new_value {
1181+
n.value.store(Owned::new(nv), Ordering::SeqCst);
1182+
// we are just replacing entry value and we do not want to remove the node
1183+
// so we stop iterating here
1184+
break;
1185+
}
1186+
// we remember the old value so that we can return it and mark it for deletion below
1187+
old_val = Some(ev);
1188+
// remove the BinEntry containing the removed key value pair from the bucket
1189+
if !pred.is_null() {
1190+
// either by changing the pointer of the previous BinEntry, if present
1191+
// safety: as above
1192+
unsafe { pred.deref() }
1193+
.as_node()
1194+
.unwrap()
1195+
.next
1196+
.store(next, Ordering::SeqCst);
1197+
} else {
1198+
// or by setting the next node as the first BinEntry if there is no previous entry
1199+
t.store_bin(i, next);
1200+
}
1201+
1202+
// in either case, mark the BinEntry as garbage, since it was just removed
1203+
// safety: as for val below / in put
1204+
unsafe { guard.defer_destroy(e) };
1205+
}
11751206
// since the key was found and only one node exists per key, we can break here
11761207
break;
11771208
}
@@ -1218,6 +1249,46 @@ where
12181249
None
12191250
}
12201251

1252+
/// Retains only the elements specified by the predicate.
1253+
///
1254+
/// In other words, remove all pairs (k, v) such that f(&k,&v) returns false.
1255+
///
1256+
/// If `f` returns `false` for a given key/value pair, but the value for that pair is concurrently
1257+
/// modified before the removal takes place, the entry will not be removed.
1258+
/// If you want the removal to happen even in the case of concurrent modification, use [`HashMap::retain_force`].
1259+
pub fn retain<F>(&mut self, mut f: F)
1260+
where
1261+
F: FnMut(&K, &V) -> bool,
1262+
{
1263+
let guard = epoch::pin();
1264+
// removed selected keys
1265+
for (k, v) in self.iter(&guard) {
1266+
if !f(k, v) {
1267+
let old_value: Shared<'_, V> = Shared::from(v as *const V);
1268+
self.replace_node(k, None, Some(old_value), &guard);
1269+
}
1270+
}
1271+
}
1272+
1273+
/// Retains only the elements specified by the predicate.
1274+
///
1275+
/// In other words, remove all pairs (k, v) such that f(&k,&v) returns false.
1276+
///
1277+
/// This method always deletes any key/value pair that `f` returns `false` for,
1278+
/// even if if the value is updated concurrently. If you do not want that behavior, use [`HashMap::retain`].
1279+
pub fn retain_force<F>(&mut self, mut f: F)
1280+
where
1281+
F: FnMut(&K, &V) -> bool,
1282+
{
1283+
let guard = epoch::pin();
1284+
// removed selected keys
1285+
for (k, v) in self.iter(&guard) {
1286+
if !f(k, v) {
1287+
self.replace_node(k, None, None, &guard);
1288+
}
1289+
}
1290+
}
1291+
12211292
/// An iterator visiting all key-value pairs in arbitrary order.
12221293
/// The iterator element type is `(&'g K, &'g V)`.
12231294
///
@@ -1604,5 +1675,54 @@ mod tests {
16041675
/// drop(guard);
16051676
/// drop(r);
16061677
/// ```
1678+
1679+
#[test]
1680+
fn replace_empty() {
1681+
let map = HashMap::<usize, usize>::new();
1682+
1683+
{
1684+
let guard = epoch::pin();
1685+
let old = map.replace_node(&42, None, None, &guard);
1686+
assert!(old.is_none());
1687+
}
1688+
}
1689+
1690+
#[test]
1691+
fn replace_existing() {
1692+
let map = HashMap::<usize, usize>::new();
1693+
{
1694+
let guard = epoch::pin();
1695+
map.insert(42, 42, &guard);
1696+
let old = map.replace_node(&42, Some(10), None, &guard);
1697+
assert!(old.is_none());
1698+
assert_eq!(*map.get(&42, &guard).unwrap(), 10);
1699+
}
1700+
}
1701+
1702+
#[test]
1703+
fn replace_existing_observed_value_matching() {
1704+
let map = HashMap::<usize, usize>::new();
1705+
{
1706+
let guard = epoch::pin();
1707+
map.insert(42, 42, &guard);
1708+
let observed_value = Shared::from(map.get(&42, &guard).unwrap() as *const _);
1709+
let old = map.replace_node(&42, Some(10), Some(observed_value), &guard);
1710+
assert!(old.is_none());
1711+
assert_eq!(*map.get(&42, &guard).unwrap(), 10);
1712+
}
1713+
}
1714+
1715+
#[test]
1716+
fn replace_existing_observed_value_non_matching() {
1717+
let map = HashMap::<usize, usize>::new();
1718+
{
1719+
let guard = epoch::pin();
1720+
map.insert(42, 42, &guard);
1721+
let old = map.replace_node(&42, Some(10), Some(Shared::null()), &guard);
1722+
assert!(old.is_none());
1723+
assert_eq!(*map.get(&42, &guard).unwrap(), 42);
1724+
}
1725+
}
1726+
16071727
#[allow(dead_code)]
16081728
struct CompileFailTests;

tests/basic.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,50 @@ fn from_iter_empty() {
445445

446446
assert_eq!(map.len(), 0)
447447
}
448+
449+
#[test]
450+
fn retain_empty() {
451+
let mut map = HashMap::<&'static str, u32>::new();
452+
map.retain(|_, _| false);
453+
assert_eq!(map.len(), 0);
454+
}
455+
456+
#[test]
457+
fn retain_all_false() {
458+
let mut map: HashMap<u32, u32> = (0..10 as u32).map(|x| (x, x)).collect();
459+
map.retain(|_, _| false);
460+
assert_eq!(map.len(), 0);
461+
}
462+
463+
#[test]
464+
fn retain_all_true() {
465+
let size = 10usize;
466+
let mut map: HashMap<usize, usize> = (0..size).map(|x| (x, x)).collect();
467+
map.retain(|_, _| true);
468+
assert_eq!(map.len(), size);
469+
}
470+
471+
#[test]
472+
fn retain_some() {
473+
let mut map: HashMap<u32, u32> = (0..10).map(|x| (x, x)).collect();
474+
let expected_map: HashMap<u32, u32> = (5..10).map(|x| (x, x)).collect();
475+
map.retain(|_, v| *v >= 5);
476+
assert_eq!(map.len(), 5);
477+
assert_eq!(map, expected_map);
478+
}
479+
480+
#[test]
481+
fn retain_force_empty() {
482+
let mut map = HashMap::<&'static str, u32>::new();
483+
map.retain_force(|_, _| false);
484+
assert_eq!(map.len(), 0);
485+
}
486+
487+
#[test]
488+
fn retain_force_some() {
489+
let mut map: HashMap<u32, u32> = (0..10).map(|x| (x, x)).collect();
490+
let expected_map: HashMap<u32, u32> = (5..10).map(|x| (x, x)).collect();
491+
map.retain_force(|_, v| *v >= 5);
492+
assert_eq!(map.len(), 5);
493+
assert_eq!(map, expected_map);
494+
}

0 commit comments

Comments
 (0)