Skip to content

Commit 9e94fa7

Browse files
task: remove raw-entry feature from hashbrown dep (#7252)
1 parent 0d234c3 commit 9e94fa7

File tree

3 files changed

+100
-105
lines changed

3 files changed

+100
-105
lines changed

tokio-util/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ slab = { version = "0.4.4", optional = true } # Backs `DelayQueue`
4545
tracing = { version = "0.1.29", default-features = false, features = ["std"], optional = true }
4646

4747
[target.'cfg(tokio_unstable)'.dependencies]
48-
hashbrown = { version = "0.15.0", default-features = false, features = ["raw-entry"], optional = true }
48+
hashbrown = { version = "0.15.0", default-features = false, optional = true }
4949

5050
[dev-dependencies]
5151
tokio = { version = "1.0.0", path = "../tokio", features = ["full"] }

tokio-util/src/task/join_map.rs

Lines changed: 66 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use hashbrown::hash_map::RawEntryMut;
2-
use hashbrown::HashMap;
1+
use hashbrown::hash_table::Entry;
2+
use hashbrown::{HashMap, HashTable};
33
use std::borrow::Borrow;
44
use std::collections::hash_map::RandomState;
55
use std::fmt;
@@ -103,13 +103,8 @@ use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
103103
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
104104
pub struct JoinMap<K, V, S = RandomState> {
105105
/// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
106-
/// indexed by their keys and task IDs.
107-
///
108-
/// The [`Key`] type contains both the task's `K`-typed key provided when
109-
/// spawning tasks, and the task's IDs. The IDs are stored here to resolve
110-
/// hash collisions when looking up tasks based on their pre-computed hash
111-
/// (as stored in the `hashes_by_task` map).
112-
tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
106+
/// indexed by their keys.
107+
tasks_by_key: HashTable<(K, AbortHandle)>,
113108

114109
/// A map from task IDs to the hash of the key associated with that task.
115110
///
@@ -125,21 +120,6 @@ pub struct JoinMap<K, V, S = RandomState> {
125120
tasks: JoinSet<V>,
126121
}
127122

128-
/// A [`JoinMap`] key.
129-
///
130-
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
131-
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
132-
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
133-
///
134-
/// This allows looking up a task using either an actual key (such as when the
135-
/// user queries the map with a key), *or* using a task ID and a hash (such as
136-
/// when removing completed tasks from the map).
137-
#[derive(Debug)]
138-
struct Key<K> {
139-
key: K,
140-
id: Id,
141-
}
142-
143123
impl<K, V> JoinMap<K, V> {
144124
/// Creates a new empty `JoinMap`.
145125
///
@@ -176,7 +156,7 @@ impl<K, V> JoinMap<K, V> {
176156
}
177157
}
178158

179-
impl<K, V, S: Clone> JoinMap<K, V, S> {
159+
impl<K, V, S> JoinMap<K, V, S> {
180160
/// Creates an empty `JoinMap` which will use the given hash builder to hash
181161
/// keys.
182162
///
@@ -226,7 +206,7 @@ impl<K, V, S: Clone> JoinMap<K, V, S> {
226206
#[must_use]
227207
pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
228208
Self {
229-
tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
209+
tasks_by_key: HashTable::with_capacity(capacity),
230210
hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
231211
tasks: JoinSet::new(),
232212
}
@@ -415,33 +395,42 @@ where
415395
self.insert(key, task)
416396
}
417397

418-
fn insert(&mut self, key: K, abort: AbortHandle) {
419-
let hash = self.hash(&key);
398+
fn insert(&mut self, mut key: K, mut abort: AbortHandle) {
399+
let hash_builder = self.hashes_by_task.hasher();
400+
let hash = hash_one(hash_builder, &key);
420401
let id = abort.id();
421-
let map_key = Key { id, key };
422402

423403
// Insert the new key into the map of tasks by keys.
424-
let entry = self
425-
.tasks_by_key
426-
.raw_entry_mut()
427-
.from_hash(hash, |k| k.key == map_key.key);
404+
let entry =
405+
self.tasks_by_key
406+
.entry(hash, |(k, _)| *k == key, |(k, _)| hash_one(hash_builder, k));
428407
match entry {
429-
RawEntryMut::Occupied(mut occ) => {
408+
Entry::Occupied(occ) => {
430409
// There was a previous task spawned with the same key! Cancel
431410
// that task, and remove its ID from the map of hashes by task IDs.
432-
let Key { id: prev_id, .. } = occ.insert_key(map_key);
433-
occ.insert(abort).abort();
434-
let _prev_hash = self.hashes_by_task.remove(&prev_id);
411+
(key, abort) = std::mem::replace(occ.into_mut(), (key, abort));
412+
413+
// Remove the old task ID.
414+
let _prev_hash = self.hashes_by_task.remove(&abort.id());
435415
debug_assert_eq!(Some(hash), _prev_hash);
416+
417+
// Associate the key's hash with the new task's ID, for looking up tasks by ID.
418+
let _prev = self.hashes_by_task.insert(id, hash);
419+
debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
420+
421+
// Note: it's important to drop `key` and abort the task here.
422+
// This defends against any panics during drop handling for causing inconsistent state.
423+
abort.abort();
424+
drop(key);
436425
}
437-
RawEntryMut::Vacant(vac) => {
438-
vac.insert(map_key, abort);
426+
Entry::Vacant(vac) => {
427+
vac.insert((key, abort));
428+
429+
// Associate the key's hash with this task's ID, for looking up tasks by ID.
430+
let _prev = self.hashes_by_task.insert(id, hash);
431+
debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
439432
}
440433
};
441-
442-
// Associate the key's hash with this task's ID, for looking up tasks by ID.
443-
let _prev = self.hashes_by_task.insert(id, hash);
444-
debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
445434
}
446435

447436
/// Waits until one of the tasks in the map completes and returns its
@@ -623,7 +612,7 @@ where
623612
// Note: this method iterates over the tasks and keys *without* removing
624613
// any entries, so that the keys from aborted tasks can still be
625614
// returned when calling `join_next` in the future.
626-
for (Key { ref key, .. }, task) in &self.tasks_by_key {
615+
for (key, task) in &self.tasks_by_key {
627616
if predicate(key) {
628617
task.abort();
629618
}
@@ -638,7 +627,7 @@ where
638627
/// [`join_next`]: fn@Self::join_next
639628
pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
640629
JoinMapKeys {
641-
iter: self.tasks_by_key.keys(),
630+
iter: self.tasks_by_key.iter(),
642631
_value: PhantomData,
643632
}
644633
}
@@ -666,7 +655,7 @@ where
666655
/// [`join_next`]: fn@Self::join_next
667656
/// [task ID]: tokio::task::Id
668657
pub fn contains_task(&self, task: &Id) -> bool {
669-
self.get_by_id(task).is_some()
658+
self.hashes_by_task.contains_key(task)
670659
}
671660

672661
/// Reserves capacity for at least `additional` more tasks to be spawned
@@ -690,7 +679,9 @@ where
690679
/// ```
691680
#[inline]
692681
pub fn reserve(&mut self, additional: usize) {
693-
self.tasks_by_key.reserve(additional);
682+
let hash_builder = self.hashes_by_task.hasher();
683+
self.tasks_by_key
684+
.reserve(additional, |(k, _)| hash_one(hash_builder, k));
694685
self.hashes_by_task.reserve(additional);
695686
}
696687

@@ -716,7 +707,9 @@ where
716707
#[inline]
717708
pub fn shrink_to_fit(&mut self) {
718709
self.hashes_by_task.shrink_to_fit();
719-
self.tasks_by_key.shrink_to_fit();
710+
let hash_builder = self.hashes_by_task.hasher();
711+
self.tasks_by_key
712+
.shrink_to_fit(|(k, _)| hash_one(hash_builder, k));
720713
}
721714

722715
/// Shrinks the capacity of the map with a lower limit. It will drop
@@ -745,27 +738,20 @@ where
745738
#[inline]
746739
pub fn shrink_to(&mut self, min_capacity: usize) {
747740
self.hashes_by_task.shrink_to(min_capacity);
748-
self.tasks_by_key.shrink_to(min_capacity)
741+
let hash_builder = self.hashes_by_task.hasher();
742+
self.tasks_by_key
743+
.shrink_to(min_capacity, |(k, _)| hash_one(hash_builder, k))
749744
}
750745

751746
/// Look up a task in the map by its key, returning the key and abort handle.
752-
fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
747+
fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
753748
where
754749
Q: Hash + Eq,
755750
K: Borrow<Q>,
756751
{
757-
let hash = self.hash(key);
758-
self.tasks_by_key
759-
.raw_entry()
760-
.from_hash(hash, |k| k.key.borrow() == key)
761-
}
762-
763-
/// Look up a task in the map by its task ID, returning the key and abort handle.
764-
fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
765-
let hash = self.hashes_by_task.get(id)?;
766-
self.tasks_by_key
767-
.raw_entry()
768-
.from_hash(*hash, |k| &k.id == id)
752+
let hash_builder = self.hashes_by_task.hasher();
753+
let hash = hash_one(hash_builder, key);
754+
self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key)
769755
}
770756

771757
/// Remove a task from the map by ID, returning the key for that task.
@@ -776,28 +762,25 @@ where
776762
// Remove the entry for that hash.
777763
let entry = self
778764
.tasks_by_key
779-
.raw_entry_mut()
780-
.from_hash(hash, |k| k.id == id);
781-
let (Key { id: _key_id, key }, handle) = match entry {
782-
RawEntryMut::Occupied(entry) => entry.remove_entry(),
765+
.find_entry(hash, |(_, abort)| abort.id() == id);
766+
let (key, _) = match entry {
767+
Ok(entry) => entry.remove().0,
783768
_ => return None,
784769
};
785-
debug_assert_eq!(_key_id, id);
786-
debug_assert_eq!(id, handle.id());
787770
self.hashes_by_task.remove(&id);
788771
Some(key)
789772
}
773+
}
790774

791-
/// Returns the hash for a given key.
792-
#[inline]
793-
fn hash<Q: ?Sized>(&self, key: &Q) -> u64
794-
where
795-
Q: Hash,
796-
{
797-
let mut hasher = self.tasks_by_key.hasher().build_hasher();
798-
key.hash(&mut hasher);
799-
hasher.finish()
800-
}
775+
/// Returns the hash for a given key.
776+
#[inline]
777+
fn hash_one<S: BuildHasher, Q: ?Sized>(hash_builder: &S, key: &Q) -> u64
778+
where
779+
Q: Hash,
780+
{
781+
let mut hasher = hash_builder.build_hasher();
782+
key.hash(&mut hasher);
783+
hasher.finish()
801784
}
802785

803786
impl<K, V, S> JoinMap<K, V, S>
@@ -831,11 +814,11 @@ impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
831814
// printing the key and task ID pairs, without format the `Key` struct
832815
// itself or the `AbortHandle`, which would just format the task's ID
833816
// again.
834-
struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
835-
impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
817+
struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
818+
impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> {
836819
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
837820
f.debug_map()
838-
.entries(self.0.keys().map(|Key { key, id }| (key, id)))
821+
.entries(self.0.iter().map(|(key, abort)| (key, abort.id())))
839822
.finish()
840823
}
841824
}
@@ -856,31 +839,10 @@ impl<K, V> Default for JoinMap<K, V> {
856839
}
857840
}
858841

859-
// === impl Key ===
860-
861-
impl<K: Hash> Hash for Key<K> {
862-
// Don't include the task ID in the hash.
863-
#[inline]
864-
fn hash<H: Hasher>(&self, hasher: &mut H) {
865-
self.key.hash(hasher);
866-
}
867-
}
868-
869-
// Because we override `Hash` for this type, we must also override the
870-
// `PartialEq` impl, so that all instances with the same hash are equal.
871-
impl<K: PartialEq> PartialEq for Key<K> {
872-
#[inline]
873-
fn eq(&self, other: &Self) -> bool {
874-
self.key == other.key
875-
}
876-
}
877-
878-
impl<K: Eq> Eq for Key<K> {}
879-
880842
/// An iterator over the keys of a [`JoinMap`].
881843
#[derive(Debug, Clone)]
882844
pub struct JoinMapKeys<'a, K, V> {
883-
iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
845+
iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
884846
/// To make it easier to change `JoinMap` in the future, keep V as a generic
885847
/// parameter.
886848
_value: PhantomData<&'a V>,
@@ -890,7 +852,7 @@ impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
890852
type Item = &'a K;
891853

892854
fn next(&mut self) -> Option<&'a K> {
893-
self.iter.next().map(|key| &key.key)
855+
self.iter.next().map(|(key, _)| key)
894856
}
895857

896858
fn size_hint(&self) -> (usize, Option<usize>) {

tokio-util/tests/task_join_map.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#![warn(rust_2018_idioms)]
22
#![cfg(all(feature = "rt", tokio_unstable))]
33

4+
use std::panic::AssertUnwindSafe;
5+
46
use tokio::sync::oneshot;
57
use tokio::time::Duration;
68
use tokio_util::task::JoinMap;
@@ -343,3 +345,34 @@ async fn duplicate_keys2() {
343345

344346
assert!(map.join_next().await.is_none());
345347
}
348+
349+
#[cfg_attr(not(panic = "unwind"), ignore)]
350+
#[tokio::test]
351+
async fn duplicate_keys_drop() {
352+
#[derive(Hash, Debug, PartialEq, Eq)]
353+
struct Key;
354+
impl Drop for Key {
355+
fn drop(&mut self) {
356+
panic!("drop called for key");
357+
}
358+
}
359+
360+
let (send, recv) = oneshot::channel::<()>();
361+
362+
let mut map = JoinMap::new();
363+
364+
map.spawn(Key, async { recv.await.unwrap() });
365+
366+
// replace the task, force it to drop the key and abort the task
367+
// we should expect it to panic when dropping the key.
368+
let _ = std::panic::catch_unwind(AssertUnwindSafe(|| map.spawn(Key, async {}))).unwrap_err();
369+
370+
// don't panic when this key drops.
371+
let (key, _) = map.join_next().await.unwrap();
372+
std::mem::forget(key);
373+
374+
// original task should have been aborted, so the sender should be dangling.
375+
assert!(send.is_closed());
376+
377+
assert!(map.join_next().await.is_none());
378+
}

0 commit comments

Comments
 (0)